From 6cd0fa18dc39f139d1636afae2a70e714b0eb60c Mon Sep 17 00:00:00 2001
From: Rodrigo Lazo Paz <rlazo@google.com>
Date: Tue, 19 Mar 2024 17:17:40 -0400
Subject: [PATCH] Add basic testing

---
 .../client/generativeai/common/util/tests.kt  |   2 -
 generativeai/build.gradle.kts                 |   5 +-
 .../ai/client/generativeai/GenerativeModel.kt |   8 +-
 .../generativeai/GenerativeModelTests.kt      | 112 +++++++++++++++++-
 4 files changed, 114 insertions(+), 13 deletions(-)

diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt
index 2655d086..d753a9bf 100644
--- a/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt
+++ b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt
@@ -18,8 +18,6 @@
 
 package com.google.ai.client.generativeai.common.util
 
-// import com.google.ai.client.generativeai.internal.util.send
-// import com.google.ai.client.generativeai.type.RequestOptions
 import com.google.ai.client.generativeai.common.APIController
 import com.google.ai.client.generativeai.common.GenerateContentRequest
 import com.google.ai.client.generativeai.common.GenerateContentResponse
diff --git a/generativeai/build.gradle.kts b/generativeai/build.gradle.kts
index 592ae54f..bac2aefa 100644
--- a/generativeai/build.gradle.kts
+++ b/generativeai/build.gradle.kts
@@ -85,8 +85,9 @@ dependencies {
     implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha02")
     implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha02")
     testImplementation("junit:junit:4.13.2")
-    testImplementation("io.kotest:kotest-assertions-core:4.0.7")
-    testImplementation("io.kotest:kotest-assertions-jvm:4.0.7")
+    testImplementation("io.kotest:kotest-assertions-core:5.5.5")
+    testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5")
+    testImplementation("io.mockk:mockk:1.13.10")
     androidTestImplementation("androidx.test.ext:junit:1.1.5")
     androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
 
diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt
index 2d572659..690f5be4 100644
--- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt
+++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt
@@ -71,11 +71,7 @@ internal constructor(
     generationConfig,
     safetySettings,
     requestOptions,
-    APIController(
-      apiKey,
-      modelName,
-      requestOptions.toInternal()
-    )
+    APIController(apiKey, modelName, requestOptions.toInternal())
   )
 
   /**
@@ -101,8 +97,8 @@ internal constructor(
   fun generateContentStream(vararg prompt: Content): Flow<GenerateContentResponse> =
     controller
       .generateContentStream(constructRequest(*prompt))
-      .map { it.toPublic().validate() }
       .catch { throw GoogleGenerativeAIException.from(it) }
+      .map { it.toPublic().validate() }
 
   /**
    * Generates a response from the backend with the provided text represented [Content].
diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt
index f2d888c7..bfc0b24c 100644
--- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt
+++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt
@@ -16,13 +16,119 @@
 
 package com.google.ai.client.generativeai
 
-import io.kotest.matchers.shouldBe
+import com.google.ai.client.generativeai.common.APIController
+import com.google.ai.client.generativeai.common.GenerateContentRequest as GenerateContentRequest_Common
+import com.google.ai.client.generativeai.common.GenerateContentResponse as GenerateContentResponse_Common
+import com.google.ai.client.generativeai.common.InvalidAPIKeyException as InvalidAPIKeyException_Common
+import com.google.ai.client.generativeai.common.UnsupportedUserLocationException as UnsupportedUserLocationException_Common
+import com.google.ai.client.generativeai.common.server.Candidate as Candidate_Common
+import com.google.ai.client.generativeai.common.shared.Content as Content_Common
+import com.google.ai.client.generativeai.common.shared.TextPart as TextPart_Common
+import com.google.ai.client.generativeai.type.Candidate
+import com.google.ai.client.generativeai.type.Content
+import com.google.ai.client.generativeai.type.GenerateContentResponse
+import com.google.ai.client.generativeai.type.InvalidAPIKeyException
+import com.google.ai.client.generativeai.type.PromptFeedback
+import com.google.ai.client.generativeai.type.TextPart
+import com.google.ai.client.generativeai.type.UnsupportedUserLocationException
+import io.kotest.assertions.throwables.shouldThrow
+import io.kotest.matchers.collections.shouldHaveSize
+import io.kotest.matchers.equality.shouldBeEqualToUsingFields
+import io.mockk.coEvery
+import io.mockk.mockk
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.runBlocking
 import org.junit.Test
 
 internal class GenerativeModelTests {
 
+  private val apiKey: String = "api_key"
+  private val mockApiController = mockk<APIController>()
+
+  @Test
+  fun `generateContent request succeeds`() = doBlocking {
+    val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController)
+    coEvery {
+      mockApiController.generateContent(
+        GenerateContentRequest_Common(
+          "gemini-pro-1.0",
+          contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?"))))
+        )
+      )
+    } returns
+      GenerateContentResponse_Common(
+        listOf(
+          Candidate_Common(
+            content =
+              Content_Common(
+                parts = listOf(TextPart_Common("I'm still learning how to answer this question"))
+              ),
+            finishReason = null,
+            safetyRatings = listOf(),
+            citationMetadata = null
+          )
+        )
+      )
+
+    val expectedResponse =
+      GenerateContentResponse(
+        listOf(
+          Candidate(
+            Content(parts = listOf(TextPart("I'm still learning how to answer this question"))),
+            safetyRatings = listOf(),
+            citationMetadata = listOf(),
+            finishReason = null
+          )
+        ),
+        PromptFeedback(null, listOf())
+      )
+
+    val response = model.generateContent("Why's the sky blue?")
+
+    response.shouldBeEqualToUsingFields(expectedResponse, GenerateContentResponse::text)
+    response.candidates shouldHaveSize expectedResponse.candidates.size
+    response.candidates[0].shouldBeEqualToUsingFields(
+      expectedResponse.candidates[0],
+      Candidate::finishReason,
+      Candidate::citationMetadata,
+      Candidate::safetyRatings
+    )
+  }
+
   @Test
-  fun `TODO`() {
-    1 shouldBe 1
+  fun `generateContent throws exception`() = doBlocking {
+    val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController)
+    coEvery {
+      mockApiController.generateContent(
+        GenerateContentRequest_Common(
+          "gemini-pro-1.0",
+          contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?"))))
+        )
+      )
+    } throws InvalidAPIKeyException_Common("exception message")
+
+    shouldThrow<InvalidAPIKeyException> { model.generateContent("Why's the sky blue?") }
   }
+
+  @Test
+  fun `generateContentStream throws exception`() = doBlocking {
+    val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController)
+    coEvery {
+      mockApiController.generateContentStream(
+        GenerateContentRequest_Common(
+          "gemini-pro-1.0",
+          contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?"))))
+        )
+      )
+    } returns flow { throw UnsupportedUserLocationException_Common() }
+
+    shouldThrow<UnsupportedUserLocationException> {
+      model.generateContentStream("Why's the sky blue?").collect {}
+    }
+  }
+}
+
+internal fun doBlocking(block: suspend CoroutineScope.() -> Unit) {
+  runBlocking(block = block)
 }