From b89694563253c151a5d2f2bbc0a649f5c2aafef0 Mon Sep 17 00:00:00 2001
From: Tharsanan1 <tharsanan.15@cse.mrt.ac.lk>
Date: Thu, 3 Oct 2024 12:48:35 +0530
Subject: [PATCH 1/2] Add AI analytics integration tests

---
 .../publishers/dto/AITokenUsage.java          | 18 +++++-----
 .../analytics/ChoreoAnalyticsProvider.java    | 28 ++++++++++------
 .../wso2/apk/integration/api/BaseSteps.java   | 33 +++++++++++++++++++
 .../api/APIBackendBasedAIRatelimit.feature    | 15 +++++++++
 4 files changed, 75 insertions(+), 19 deletions(-)

diff --git a/gateway/enforcer/org.wso2.apk.enforcer.commons/src/main/java/org/wso2/apk/enforcer/commons/analytics/publishers/dto/AITokenUsage.java b/gateway/enforcer/org.wso2.apk.enforcer.commons/src/main/java/org/wso2/apk/enforcer/commons/analytics/publishers/dto/AITokenUsage.java
index b381fd271..5607fd8bc 100644
--- a/gateway/enforcer/org.wso2.apk.enforcer.commons/src/main/java/org/wso2/apk/enforcer/commons/analytics/publishers/dto/AITokenUsage.java
+++ b/gateway/enforcer/org.wso2.apk.enforcer.commons/src/main/java/org/wso2/apk/enforcer/commons/analytics/publishers/dto/AITokenUsage.java
@@ -24,18 +24,18 @@
  */
 public class AITokenUsage {
     @JsonProperty("totalTokens")
-    private Double totalTokens;
+    private Integer totalTokens;
 
     @JsonProperty("promptTokens")
-    private Double promptTokens;
+    private Integer promptTokens;
 
     @JsonProperty("completionTokens")
-    private Double completionTokens;
+    private Integer completionTokens;
 
     @JsonProperty("hour")
     private Integer hour;
 
-    public Double getTotalTokens() {
+    public Integer getTotalTokens() {
 
         return totalTokens;
     }
@@ -50,27 +50,27 @@ public void setHour(Integer hour) {
         this.hour = hour;
     }
 
-    public void setTotalTokens(Double totalTokens) {
+    public void setTotalTokens(Integer totalTokens) {
 
         this.totalTokens = totalTokens;
     }
 
-    public Double getPromptTokens() {
+    public Integer getPromptTokens() {
 
         return promptTokens;
     }
 
-    public void setPromptTokens(Double promptTokens) {
+    public void setPromptTokens(Integer promptTokens) {
 
         this.promptTokens = promptTokens;
     }
 
-    public Double getCompletionTokens() {
+    public Integer getCompletionTokens() {
 
         return completionTokens;
     }
 
-    public void setCompletionTokens(Double completionTokens) {
+    public void setCompletionTokens(Integer completionTokens) {
 
         this.completionTokens = completionTokens;
     }
diff --git a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/analytics/ChoreoAnalyticsProvider.java b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/analytics/ChoreoAnalyticsProvider.java
index 40784b448..bf92c8927 100644
--- a/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/analytics/ChoreoAnalyticsProvider.java
+++ b/gateway/enforcer/org.wso2.apk.enforcer/src/main/java/org/wso2/apk/enforcer/analytics/ChoreoAnalyticsProvider.java
@@ -23,7 +23,6 @@
 import com.google.protobuf.Value;
 import io.envoyproxy.envoy.data.accesslog.v3.AccessLogCommon;
 import io.envoyproxy.envoy.data.accesslog.v3.HTTPAccessLogEntry;
-import io.envoyproxy.envoy.service.ext_proc.v3.ProcessingResponse;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.wso2.apk.enforcer.commons.analytics.collectors.AnalyticsCustomDataProvider;
@@ -260,9 +259,9 @@ public Map<String, Object> getProperties() {
         Map<String,Object> map = new HashMap();
         Map<String, Value> fieldsMap = getFieldsMapFromLogEntry();
         String gwURL = getValueAsString(fieldsMap, MetadataConstants.GATEWAY_URL);
-        Double totalTokenCount = getValueAsDouble(fieldsMap, MetadataConstants.TOTAL_TOKEN_COUNT);
-        Double completionTokenCount = getValueAsDouble(fieldsMap, MetadataConstants.COMPLETION_TOKEN_COUNT);
-        Double promptTokenCount = getValueAsDouble(fieldsMap, MetadataConstants.PROMPT_TOKEN_COUNT);
+        Integer totalTokenCount = getValueAsInteger(fieldsMap, MetadataConstants.TOTAL_TOKEN_COUNT);
+        Integer completionTokenCount = getValueAsInteger(fieldsMap, MetadataConstants.COMPLETION_TOKEN_COUNT);
+        Integer promptTokenCount = getValueAsInteger(fieldsMap, MetadataConstants.PROMPT_TOKEN_COUNT);
         String model = getValueAsString(fieldsMap, MetadataConstants.MODEL);
         String providerName = getValueAsString(fieldsMap, MetadataConstants.AI_PROVIDER_NAME);
         String providerAPIVersion = getValueAsString(fieldsMap, MetadataConstants.AI_PROVIDER_API_VERSION);
@@ -308,12 +307,13 @@ private String getValueAsString(Map<String, Value> fieldsMap, String key) {
         return fieldsMap.get(key).getStringValue();
     }
 
-    private Double getValueAsDouble(Map<String, Value> fieldsMap, String key) {
+    private Integer getValueAsInteger(Map<String, Value> fieldsMap, String key) {
 
         if (fieldsMap == null || !fieldsMap.containsKey(key)) {
             return null;
         }
-        return fieldsMap.get(key).getNumberValue();
+        Double d = fieldsMap.get(key).getNumberValue();
+        return d.intValue();
     }
 
     private Map<String, Value> getFieldsMapFromLogEntry() {
@@ -325,10 +325,18 @@ private Map<String, Value> getFieldsMapFromLogEntry() {
                 .containsKey(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY)) {
             return new HashMap<>(0);
         }
-        Map<String, Value> metadataFromExtProc = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
-                .get(MetadataConstants.EXT_PROC_METADATA_CONTEXT_KEY).getFieldsMap();
-        Map<String, Value> metadataFromExtAuthz = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
-                .get(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY).getFieldsMap();
+        Map<String, Value> metadataFromExtProc = new HashMap<>();
+        if (logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
+                .get(MetadataConstants.EXT_PROC_METADATA_CONTEXT_KEY) != null) {
+            metadataFromExtProc = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
+                    .get(MetadataConstants.EXT_PROC_METADATA_CONTEXT_KEY).getFieldsMap();
+        }
+        Map<String, Value> metadataFromExtAuthz = new HashMap<>();
+        if (logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
+                .get(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY) != null) {
+            metadataFromExtAuthz = logEntry.getCommonProperties().getMetadata().getFilterMetadataMap()
+                    .get(MetadataConstants.EXT_AUTH_METADATA_CONTEXT_KEY).getFieldsMap();
+        }
         Map<String, Value> mergedMetadata = new HashMap<>(metadataFromExtProc);
         mergedMetadata.putAll(metadataFromExtAuthz);
         return mergedMetadata;
diff --git a/test/cucumber-tests/src/test/java/org/wso2/apk/integration/api/BaseSteps.java b/test/cucumber-tests/src/test/java/org/wso2/apk/integration/api/BaseSteps.java
index 9d0d09795..e7245f198 100644
--- a/test/cucumber-tests/src/test/java/org/wso2/apk/integration/api/BaseSteps.java
+++ b/test/cucumber-tests/src/test/java/org/wso2/apk/integration/api/BaseSteps.java
@@ -39,6 +39,12 @@
 import io.cucumber.java.Before;
 import io.cucumber.java.en.Given;
 import io.cucumber.java.en.Then;
+import io.kubernetes.client.openapi.ApiClient;
+import io.kubernetes.client.openapi.ApiException;
+import io.kubernetes.client.openapi.Configuration;
+import io.kubernetes.client.openapi.apis.CoreV1Api;
+import io.kubernetes.client.openapi.models.V1Pod;
+import io.kubernetes.client.util.Config;
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -71,6 +77,8 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+
 import io.grpc.Status;
 import io.grpc.StatusRuntimeException;
 
@@ -305,6 +313,31 @@ public void eventualSuccess(int statusCode, DataTable dataTable) throws IOExcept
         }
     }
 
+    @Then("I see following strings in the enforcer logs")
+    public void checkEnforcerLogs(DataTable dataTable) throws IOException, InterruptedException, ApiException {
+        List<String> stringsToCheck = dataTable.asList(String.class);
+        ApiClient client = Config.defaultClient();
+        Configuration.setDefaultApiClient(client);
+        CoreV1Api api = new CoreV1Api();
+        String namespace = "apk-integration-test";
+        String podName = "your-pod-name";
+        String labelSelector = "app.kubernetes.io/app=gateway";
+
+        List<V1Pod> podList = api.listNamespacedPod(namespace).labelSelector(labelSelector).execute().getItems();
+        if (!podList.isEmpty()) {
+            podName = Objects.requireNonNull(podList.get(0).getMetadata()).getName();
+        }
+        try {
+            String logs = api.readNamespacedPodLog(podName, namespace).container("enforcer").sinceSeconds(60).execute();
+            Assert.assertNotNull(logs, String.format("Could not find any logs in the last 60 seconds. PodName: %s, namespace: %s", podName, namespace));
+            for(String word : stringsToCheck) {
+                Assert.assertTrue(logs.contains(word), "Expected word '" + word + "' not found in logs");
+            }
+        } catch (ApiException e) {
+            System.out.println(e);
+        }
+    }
+
     @Then("I wait for next minute")
     public void waitForNextMinute() throws InterruptedException {
         LocalDateTime now = LocalDateTime.now();
diff --git a/test/cucumber-tests/src/test/resources/tests/api/APIBackendBasedAIRatelimit.feature b/test/cucumber-tests/src/test/resources/tests/api/APIBackendBasedAIRatelimit.feature
index 13fb3c024..5c18d0e09 100644
--- a/test/cucumber-tests/src/test/resources/tests/api/APIBackendBasedAIRatelimit.feature
+++ b/test/cucumber-tests/src/test/resources/tests/api/APIBackendBasedAIRatelimit.feature
@@ -10,6 +10,21 @@ Feature: API backend based AI ratelimit Feature
     Then the response status code should be 200
     And the response headers should contain
       | x-ratelimit-remaining | 4999 |
+    Then I see following strings in the enforcer logs
+      |aiMetadata|
+      |gpt-35-turbo|
+      |AzureAI|
+      |2024-06-01|
+      |aiTokenUsage|
+      |1000|
+      |300|
+      |500|
+      |hour|
+      |vendor_name|
+      |vendor_version|
+      |totalTokens|
+      |promptTokens|
+      |completionTokens|
     And I wait for 3 seconds
     And I send "GET" request to "https://default.gw.wso2.com:9095/llm-api/v1.0.0/3.14/employee?send=body" with body ""
     Then the response status code should be 200

From bc720d9552aa19df39b24512f9a751f7d3c8d2e6 Mon Sep 17 00:00:00 2001
From: Tharsanan1 <tharsanan.15@cse.mrt.ac.lk>
Date: Mon, 7 Oct 2024 09:44:02 +0530
Subject: [PATCH 2/2] Add dependency

---
 test/cucumber-tests/build.gradle | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/cucumber-tests/build.gradle b/test/cucumber-tests/build.gradle
index d7e3821d3..1c35a066b 100644
--- a/test/cucumber-tests/build.gradle
+++ b/test/cucumber-tests/build.gradle
@@ -51,6 +51,7 @@ dependencies {
     implementation 'io.grpc:grpc-stub:1.57.0'
     implementation 'io.grpc:grpc-stub:1.57.0'
     implementation 'com.google.protobuf:protobuf-java:4.28.2'
+    implementation group: 'io.kubernetes', name: 'client-java', version: '21.0.1'
 }
 
 test {