Skip to content

Commit

Permalink
Add unit tests to common (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
rlazo authored Mar 18, 2024
1 parent b8857dd commit 7507710
Show file tree
Hide file tree
Showing 36 changed files with 1,471 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.util.decodeToFlow
import io.ktor.client.HttpClient
import io.ktor.client.call.body
Expand All @@ -37,7 +38,9 @@ import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.timeout
import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -79,28 +82,39 @@ class APIController(
}

suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()

fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> {
return client.postStream<GenerateContentResponse>(
"$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
try {
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body<GenerateContentResponse>()
.validate()
} catch (e: Throwable) {
throw GoogleGenerativeAIException.from(e)
}
}

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> =
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") {
.postStream<GenerateContentResponse>(
"$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()
.map { it.validate() }
.catch { throw GoogleGenerativeAIException.from(it) }

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
try {
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()
} catch (e: Throwable) {
throw GoogleGenerativeAIException.from(e)
}

private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) {
when (request) {
Expand Down Expand Up @@ -165,21 +179,31 @@ private inline fun <reified R : Response> HttpClient.postStream(
}

private suspend fun validateResponse(response: HttpResponse) {
if (response.status != HttpStatusCode.OK) {
val text = response.bodyAsText()
val message =
try {
JSON.decodeFromString<GRpcErrorResponse>(text).error.message
} catch (e: Throwable) {
"Unexpected Response:\n$text"
}
if (message.contains("API key not valid")) {
throw InvalidAPIKeyException(message)
if (response.status == HttpStatusCode.OK) return
val text = response.bodyAsText()
val message =
try {
JSON.decodeFromString<GRpcErrorResponse>(text).error.message
} catch (e: Throwable) {
"Unexpected Response:\n$text"
}
// TODO (b/325117891): Use a better method than string matching.
if (message == "User location is not supported for the API use.") {
throw UnsupportedUserLocationException()
}
throw ServerException(message)
if (message.contains("API key not valid")) {
throw InvalidAPIKeyException(message)
}
// TODO (b/325117891): Use a better method than string matching.
if (message == "User location is not supported for the API use.") {
throw UnsupportedUserLocationException()
}
throw ServerException(message)
}

private fun GenerateContentResponse.validate() = apply {
if ((candidates?.isEmpty() != false) && promptFeedback == null) {
throw SerializationException("Error deserializing response, found no valid fields")
}
promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
candidates
?.mapNotNull { it.finishReason }
?.firstOrNull { it != FinishReason.STOP }
?.let { throw ResponseStoppedException(this) }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.common.util.commonTest
import com.google.ai.client.generativeai.common.util.createResponses
import com.google.ai.client.generativeai.common.util.doBlocking
import com.google.ai.client.generativeai.common.util.prepareStreamingResponse
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.withTimeout
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

internal class GenerativeModelTests {
private val testTimeout = 5.seconds

@Test
fun `(generateContentStream) emits responses as they come in`() = commonTest {
val response = createResponses("The", " world", " is", " a", " beautiful", " place!")
val bytes = prepareStreamingResponse(response)

bytes.forEach { channel.writeFully(it) }
val responses = apiController.generateContentStream(textGenerateContentRequest("test"))

withTimeout(testTimeout) {
responses.collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}
}

@Test
fun `(generateContent) respects a custom timeout`() =
commonTest(requestOptions = RequestOptions(2.seconds)) {
shouldThrow<RequestTimeoutException> {
withTimeout(testTimeout) {
apiController.generateContent(textGenerateContentRequest("test"))
}
}
}
}

@RunWith(Parameterized::class)
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {

@Test
fun `request should include right model name`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller = APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}

mockEngine.requestHistory.first().url.encodedPath shouldContain actualName
}

companion object {
@JvmStatic
@Parameterized.Parameters
fun data() =
listOf(
arrayOf("gemini-pro", "models/gemini-pro"),
arrayOf("x/gemini-pro", "x/gemini-pro"),
arrayOf("models/gemini-pro", "models/gemini-pro"),
arrayOf("/modelname", "/modelname"),
arrayOf("modifiedNaming/mymodel", "modifiedNaming/mymodel"),
)
}
}

fun textGenerateContentRequest(prompt: String) =
GenerateContentRequest(
model = "unused",
contents = listOf(Content(parts = listOf(TextPart(prompt))))
)
Loading

0 comments on commit 7507710

Please sign in to comment.