Skip to content

Commit

Permalink
Add basic testing
Browse files Browse the repository at this point in the history
  • Loading branch information
rlazo committed Mar 19, 2024
1 parent 77a7e28 commit 6cd0fa1
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

package com.google.ai.client.generativeai.common.util

// import com.google.ai.client.generativeai.internal.util.send
// import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.common.APIController
import com.google.ai.client.generativeai.common.GenerateContentRequest
import com.google.ai.client.generativeai.common.GenerateContentResponse
Expand Down
5 changes: 3 additions & 2 deletions generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ dependencies {
implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha02")
implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha02")
testImplementation("junit:junit:4.13.2")
testImplementation("io.kotest:kotest-assertions-core:4.0.7")
testImplementation("io.kotest:kotest-assertions-jvm:4.0.7")
testImplementation("io.kotest:kotest-assertions-core:5.5.5")
testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5")
testImplementation("io.mockk:mockk:1.13.10")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,7 @@ internal constructor(
generationConfig,
safetySettings,
requestOptions,
APIController(
apiKey,
modelName,
requestOptions.toInternal()
)
APIController(apiKey, modelName, requestOptions.toInternal())
)

/**
Expand All @@ -101,8 +97,8 @@ internal constructor(
fun generateContentStream(vararg prompt: Content): Flow<GenerateContentResponse> =
controller
.generateContentStream(constructRequest(*prompt))
.map { it.toPublic().validate() }
.catch { throw GoogleGenerativeAIException.from(it) }
.map { it.toPublic().validate() }

/**
* Generates a response from the backend with the provided text represented [Content].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,119 @@

package com.google.ai.client.generativeai

import io.kotest.matchers.shouldBe
import com.google.ai.client.generativeai.common.APIController
import com.google.ai.client.generativeai.common.GenerateContentRequest as GenerateContentRequest_Common
import com.google.ai.client.generativeai.common.GenerateContentResponse as GenerateContentResponse_Common
import com.google.ai.client.generativeai.common.InvalidAPIKeyException as InvalidAPIKeyException_Common
import com.google.ai.client.generativeai.common.UnsupportedUserLocationException as UnsupportedUserLocationException_Common
import com.google.ai.client.generativeai.common.server.Candidate as Candidate_Common
import com.google.ai.client.generativeai.common.shared.Content as Content_Common
import com.google.ai.client.generativeai.common.shared.TextPart as TextPart_Common
import com.google.ai.client.generativeai.type.Candidate
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.InvalidAPIKeyException
import com.google.ai.client.generativeai.type.PromptFeedback
import com.google.ai.client.generativeai.type.TextPart
import com.google.ai.client.generativeai.type.UnsupportedUserLocationException
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.collections.shouldHaveSize
import io.kotest.matchers.equality.shouldBeEqualToUsingFields
import io.mockk.coEvery
import io.mockk.mockk
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.runBlocking
import org.junit.Test

internal class GenerativeModelTests {

private val apiKey: String = "api_key"
private val mockApiController = mockk<APIController>()

@Test
fun `generateContent request succeeds`() = doBlocking {
val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController)
coEvery {
mockApiController.generateContent(
GenerateContentRequest_Common(
"gemini-pro-1.0",
contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?"))))
)
)
} returns
GenerateContentResponse_Common(
listOf(
Candidate_Common(
content =
Content_Common(
parts = listOf(TextPart_Common("I'm still learning how to answer this question"))
),
finishReason = null,
safetyRatings = listOf(),
citationMetadata = null
)
)
)

val expectedResponse =
GenerateContentResponse(
listOf(
Candidate(
Content(parts = listOf(TextPart("I'm still learning how to answer this question"))),
safetyRatings = listOf(),
citationMetadata = listOf(),
finishReason = null
)
),
PromptFeedback(null, listOf())
)

val response = model.generateContent("Why's the sky blue?")

response.shouldBeEqualToUsingFields(expectedResponse, GenerateContentResponse::text)
response.candidates shouldHaveSize expectedResponse.candidates.size
response.candidates[0].shouldBeEqualToUsingFields(
expectedResponse.candidates[0],
Candidate::finishReason,
Candidate::citationMetadata,
Candidate::safetyRatings
)
}

@Test
fun `TODO`() {
1 shouldBe 1
fun `generateContent throws exception`() = doBlocking {
val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController)
coEvery {
mockApiController.generateContent(
GenerateContentRequest_Common(
"gemini-pro-1.0",
contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?"))))
)
)
} throws InvalidAPIKeyException_Common("exception message")

shouldThrow<InvalidAPIKeyException> { model.generateContent("Why's the sky blue?") }
}

@Test
fun `generateContentStream throws exception`() = doBlocking {
val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController)
coEvery {
mockApiController.generateContentStream(
GenerateContentRequest_Common(
"gemini-pro-1.0",
contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?"))))
)
)
} returns flow { throw UnsupportedUserLocationException_Common() }

shouldThrow<UnsupportedUserLocationException> {
model.generateContentStream("Why's the sky blue?").collect {}
}
}
}

internal fun doBlocking(block: suspend CoroutineScope.() -> Unit) {
runBlocking(block = block)
}

0 comments on commit 6cd0fa1

Please sign in to comment.