diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt index 2655d086..d753a9bf 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt @@ -18,8 +18,6 @@ package com.google.ai.client.generativeai.common.util -// import com.google.ai.client.generativeai.internal.util.send -// import com.google.ai.client.generativeai.type.RequestOptions import com.google.ai.client.generativeai.common.APIController import com.google.ai.client.generativeai.common.GenerateContentRequest import com.google.ai.client.generativeai.common.GenerateContentResponse diff --git a/generativeai/build.gradle.kts b/generativeai/build.gradle.kts index 592ae54f..bac2aefa 100644 --- a/generativeai/build.gradle.kts +++ b/generativeai/build.gradle.kts @@ -85,8 +85,9 @@ dependencies { implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha02") implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha02") testImplementation("junit:junit:4.13.2") - testImplementation("io.kotest:kotest-assertions-core:4.0.7") - testImplementation("io.kotest:kotest-assertions-jvm:4.0.7") + testImplementation("io.kotest:kotest-assertions-core:5.5.5") + testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5") + testImplementation("io.mockk:mockk:1.13.10") androidTestImplementation("androidx.test.ext:junit:1.1.5") androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 2d572659..690f5be4 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -71,11 +71,7 @@ internal constructor( generationConfig, safetySettings, requestOptions, - APIController( - apiKey, - modelName, - requestOptions.toInternal() - ) + APIController(apiKey, modelName, requestOptions.toInternal()) ) /** @@ -101,8 +97,8 @@ internal constructor( fun generateContentStream(vararg prompt: Content): Flow = controller .generateContentStream(constructRequest(*prompt)) - .map { it.toPublic().validate() } .catch { throw GoogleGenerativeAIException.from(it) } + .map { it.toPublic().validate() } /** * Generates a response from the backend with the provided text represented [Content]. diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index f2d888c7..bfc0b24c 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -16,13 +16,119 @@ package com.google.ai.client.generativeai -import io.kotest.matchers.shouldBe +import com.google.ai.client.generativeai.common.APIController +import com.google.ai.client.generativeai.common.GenerateContentRequest as GenerateContentRequest_Common +import com.google.ai.client.generativeai.common.GenerateContentResponse as GenerateContentResponse_Common +import com.google.ai.client.generativeai.common.InvalidAPIKeyException as InvalidAPIKeyException_Common +import com.google.ai.client.generativeai.common.UnsupportedUserLocationException as UnsupportedUserLocationException_Common +import com.google.ai.client.generativeai.common.server.Candidate as Candidate_Common +import com.google.ai.client.generativeai.common.shared.Content as Content_Common +import com.google.ai.client.generativeai.common.shared.TextPart as TextPart_Common +import com.google.ai.client.generativeai.type.Candidate +import com.google.ai.client.generativeai.type.Content +import com.google.ai.client.generativeai.type.GenerateContentResponse +import com.google.ai.client.generativeai.type.InvalidAPIKeyException +import com.google.ai.client.generativeai.type.PromptFeedback +import com.google.ai.client.generativeai.type.TextPart +import com.google.ai.client.generativeai.type.UnsupportedUserLocationException +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.equality.shouldBeEqualToUsingFields +import io.mockk.coEvery +import io.mockk.mockk +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.runBlocking import org.junit.Test internal class GenerativeModelTests { + private val apiKey: String = "api_key" + private val mockApiController = mockk() + + @Test + fun `generateContent request succeeds`() = doBlocking { + val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController) + coEvery { + mockApiController.generateContent( + GenerateContentRequest_Common( + "gemini-pro-1.0", + contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))) + ) + ) + } returns + GenerateContentResponse_Common( + listOf( + Candidate_Common( + content = + Content_Common( + parts = listOf(TextPart_Common("I'm still learning how to answer this question")) + ), + finishReason = null, + safetyRatings = listOf(), + citationMetadata = null + ) + ) + ) + + val expectedResponse = + GenerateContentResponse( + listOf( + Candidate( + Content(parts = listOf(TextPart("I'm still learning how to answer this question"))), + safetyRatings = listOf(), + citationMetadata = listOf(), + finishReason = null + ) + ), + PromptFeedback(null, listOf()) + ) + + val response = model.generateContent("Why's the sky blue?") + + response.shouldBeEqualToUsingFields(expectedResponse, GenerateContentResponse::text) + response.candidates shouldHaveSize expectedResponse.candidates.size + response.candidates[0].shouldBeEqualToUsingFields( + expectedResponse.candidates[0], + Candidate::finishReason, + Candidate::citationMetadata, + Candidate::safetyRatings + ) + } + @Test - fun `TODO`() { - 1 shouldBe 1 + fun `generateContent throws exception`() = doBlocking { + val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController) + coEvery { + mockApiController.generateContent( + GenerateContentRequest_Common( + "gemini-pro-1.0", + contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))) + ) + ) + } throws InvalidAPIKeyException_Common("exception message") + + shouldThrow { model.generateContent("Why's the sky blue?") } } + + @Test + fun `generateContentStream throws exception`() = doBlocking { + val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController) + coEvery { + mockApiController.generateContentStream( + GenerateContentRequest_Common( + "gemini-pro-1.0", + contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))) + ) + ) + } returns flow { throw UnsupportedUserLocationException_Common() } + + shouldThrow { + model.generateContentStream("Why's the sky blue?").collect {} + } + } +} + +internal fun doBlocking(block: suspend CoroutineScope.() -> Unit) { + runBlocking(block = block) }