Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests to common #85

Merged
merged 7 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.util.decodeToFlow
import io.ktor.client.HttpClient
import io.ktor.client.call.body
Expand All @@ -37,7 +38,9 @@ import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.timeout
import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -79,28 +82,39 @@ class APIController(
}

suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()

fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> {
return client.postStream<GenerateContentResponse>(
"$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
try {
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body<GenerateContentResponse>()
.validate()
} catch (e: Throwable) {
throw GoogleGenerativeAIException.from(e)
}
}

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> =
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") {
.postStream<GenerateContentResponse>(
"$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
daymxn marked this conversation as resolved.
Show resolved Hide resolved
) {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()
.map { it.validate() }
.catch { throw GoogleGenerativeAIException.from(it) }

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
try {
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()
} catch (e: Throwable) {
throw GoogleGenerativeAIException.from(e)
}

private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) {
when (request) {
Expand Down Expand Up @@ -165,21 +179,31 @@ private inline fun <reified R : Response> HttpClient.postStream(
}

private suspend fun validateResponse(response: HttpResponse) {
if (response.status != HttpStatusCode.OK) {
val text = response.bodyAsText()
val message =
try {
JSON.decodeFromString<GRpcErrorResponse>(text).error.message
} catch (e: Throwable) {
"Unexpected Response:\n$text"
}
if (message.contains("API key not valid")) {
throw InvalidAPIKeyException(message)
if (response.status == HttpStatusCode.OK) return
val text = response.bodyAsText()
val message =
try {
JSON.decodeFromString<GRpcErrorResponse>(text).error.message
} catch (e: Throwable) {
"Unexpected Response:\n$text"
rlazo marked this conversation as resolved.
Show resolved Hide resolved
}
// TODO (b/325117891): Use a better method than string matching.
if (message == "User location is not supported for the API use.") {
throw UnsupportedUserLocationException()
}
throw ServerException(message)
if (message.contains("API key not valid")) {
throw InvalidAPIKeyException(message)
}
// TODO (b/325117891): Use a better method than string matching.
if (message == "User location is not supported for the API use.") {
throw UnsupportedUserLocationException()
}
throw ServerException(message)
}

private fun GenerateContentResponse.validate() = apply {
if ((candidates?.isEmpty() != false) && promptFeedback == null) {
throw SerializationException("Error deserializing response, found no valid fields")
}
promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
candidates
?.mapNotNull { it.finishReason }
?.firstOrNull { it != FinishReason.STOP }
?.let { throw ResponseStoppedException(this) }
rlazo marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.common.util.commonTest
import com.google.ai.client.generativeai.common.util.createResponses
import com.google.ai.client.generativeai.common.util.doBlocking
import com.google.ai.client.generativeai.common.util.prepareStreamingResponse
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.withTimeout
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

internal class GenerativeModelTests {
private val testTimeout = 5.seconds

@Test
fun `(generateContentStream) emits responses as they come in`() = commonTest {
val response = createResponses("The", " world", " is", " a", " beautiful", " place!")
val bytes = prepareStreamingResponse(response)

bytes.forEach { channel.writeFully(it) }
val responses = apiController.generateContentStream(textGenerateContentRequest("test"))

withTimeout(testTimeout) {
responses.collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}
}

@Test
fun `(generateContent) respects a custom timeout`() =
commonTest(requestOptions = RequestOptions(2.seconds)) {
shouldThrow<RequestTimeoutException> {
withTimeout(testTimeout) {
apiController.generateContent(textGenerateContentRequest("test"))
}
}
}
}

@RunWith(Parameterized::class)
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {

@Test
fun `request should include right model name`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller = APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}

mockEngine.requestHistory.first().url.encodedPath shouldContain actualName
}

companion object {
@JvmStatic
@Parameterized.Parameters
fun data() =
listOf(
arrayOf("gemini-pro", "models/gemini-pro"),
arrayOf("x/gemini-pro", "x/gemini-pro"),
arrayOf("models/gemini-pro", "models/gemini-pro"),
arrayOf("/modelname", "/modelname"),
arrayOf("modifiedNaming/mymodel", "modifiedNaming/mymodel"),
)
}
}

fun textGenerateContentRequest(prompt: String) =
GenerateContentRequest(
model = "unused",
contents = listOf(Content(parts = listOf(TextPart(prompt))))
)
Loading
Loading