From 9c4aaf0b224a1159a06d8ae7c61d32868dfe1bfc Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Mon, 10 Jun 2024 10:08:34 -0700 Subject: [PATCH] Make CountTokenRequest contain the full request (#152) Co-authored-by: David Motsonashvili --- .../generativeai/common/APIController.kt | 8 +-- .../ai/client/generativeai/common/Request.kt | 31 ++++++++-- .../client/generativeai/common/util/util.kt | 24 ++++++++ .../generativeai/common/APIControllerTests.kt | 57 +------------------ .../ai/client/generativeai/GenerativeModel.kt | 5 +- 5 files changed, 56 insertions(+), 69 deletions(-) create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/util/util.kt diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt index ebac13fd..b23faa76 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt @@ -19,6 +19,7 @@ package com.google.ai.client.generativeai.common import android.util.Log import com.google.ai.client.generativeai.common.server.FinishReason import com.google.ai.client.generativeai.common.util.decodeToFlow +import com.google.ai.client.generativeai.common.util.fullModelName import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.engine.HttpClientEngine @@ -213,13 +214,6 @@ interface HeaderProvider { suspend fun generateHeaders(): Map } -/** - * Ensures the model name provided has a `models/` prefix - * - * Models must be prepended with the `models/` prefix when communicating with the backend. - */ -private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name" - private suspend fun validateResponse(response: HttpResponse) { if (response.status == HttpStatusCode.OK) return val text = response.bodyAsText() diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index 09171a15..b252b2e1 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -21,15 +21,15 @@ import com.google.ai.client.generativeai.common.client.Tool import com.google.ai.client.generativeai.common.client.ToolConfig import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.SafetySetting +import com.google.ai.client.generativeai.common.util.fullModelName import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable -import kotlinx.serialization.Transient sealed interface Request @Serializable data class GenerateContentRequest( - @Transient val model: String? = null, + val model: String? = null, val contents: List, @SerialName("safety_settings") val safetySettings: List? = null, @SerialName("generation_config") val generationConfig: GenerationConfig? = null, @@ -39,5 +39,28 @@ data class GenerateContentRequest( ) : Request @Serializable -data class CountTokensRequest(@Transient val model: String? = null, val contents: List) : - Request +data class CountTokensRequest( + val generateContentRequest: GenerateContentRequest? = null, + val model: String? = null, + val contents: List? = null, + val tools: List? = null, + @SerialName("system_instruction") val systemInstruction: Content? = null, +) : Request { + companion object { + fun forGenAI(generateContentRequest: GenerateContentRequest) = + CountTokensRequest( + generateContentRequest = + generateContentRequest.model?.let { + generateContentRequest.copy(model = fullModelName(it)) + } ?: generateContentRequest + ) + + fun forVertexAI(generateContentRequest: GenerateContentRequest) = + CountTokensRequest( + model = generateContentRequest.model?.let { fullModelName(it) }, + contents = generateContentRequest.contents, + tools = generateContentRequest.tools, + systemInstruction = generateContentRequest.systemInstruction, + ) + } +} diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/util.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/util.kt new file mode 100644 index 00000000..56a8a97d --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/util.kt @@ -0,0 +1,24 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.util + +/** + * Ensures the model name provided has a `models/` prefix + * + * Models must be prepended with the `models/` prefix when communicating with the backend. + */ +fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name" diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt index 31171ece..776e0d0c 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt @@ -25,7 +25,6 @@ import com.google.ai.client.generativeai.common.util.createResponses import com.google.ai.client.generativeai.common.util.doBlocking import com.google.ai.client.generativeai.common.util.prepareStreamingResponse import io.kotest.assertions.json.shouldContainJsonKey -import io.kotest.assertions.json.shouldNotContainJsonKey import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain @@ -135,60 +134,6 @@ internal class RequestFormatTests { mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint" } - @Test - fun `generateContentRequest doesn't include the model name`() = doBlocking { - val channel = ByteChannel(autoFlush = true) - val mockEngine = MockEngine { - respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) - } - prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } - val controller = - APIController( - "super_cool_test_key", - "gemini-pro-1.0", - RequestOptions(), - mockEngine, - TEST_CLIENT_ID, - null, - ) - - withTimeout(5.seconds) { - controller.generateContentStream(textGenerateContentRequest("cats")).collect { - it.candidates?.isEmpty() shouldBe false - channel.close() - } - } - - val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text - requestBodyAsText shouldContainJsonKey "contents" - requestBodyAsText shouldNotContainJsonKey "model" - } - - @Test - fun `countTokenRequest doesn't include the model name`() = doBlocking { - val response = - JSON.encodeToString(CountTokensResponse(totalTokens = 10, totalBillableCharacters = 10)) - val mockEngine = MockEngine { - respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) - } - - val controller = - APIController( - "super_cool_test_key", - "gemini-pro-1.0", - RequestOptions(), - mockEngine, - TEST_CLIENT_ID, - null, - ) - - withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) } - - val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text - requestBodyAsText shouldContainJsonKey "contents" - requestBodyAsText shouldNotContainJsonKey "model" - } - @Test fun `client id header is set correctly in the request`() = doBlocking { val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10)) @@ -367,4 +312,4 @@ fun textGenerateContentRequest(prompt: String) = ) fun textCountTokenRequest(prompt: String) = - CountTokensRequest(model = "unused", contents = listOf(Content(parts = listOf(TextPart(prompt))))) + CountTokensRequest(generateContentRequest = textGenerateContentRequest(prompt)) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 6a4618a9..6ea9df87 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -20,6 +20,7 @@ import android.graphics.Bitmap import com.google.ai.client.generativeai.common.APIController import com.google.ai.client.generativeai.common.CountTokensRequest import com.google.ai.client.generativeai.common.GenerateContentRequest +import com.google.ai.client.generativeai.common.util.fullModelName import com.google.ai.client.generativeai.internal.util.toInternal import com.google.ai.client.generativeai.internal.util.toPublic import com.google.ai.client.generativeai.type.Content @@ -85,7 +86,7 @@ internal constructor( toolConfig: ToolConfig? = null, systemInstruction: Content? = null, ) : this( - modelName, + fullModelName(modelName), apiKey, generationConfig, safetySettings, @@ -240,7 +241,7 @@ internal constructor( ) private fun constructCountTokensRequest(vararg prompt: Content) = - CountTokensRequest(modelName, prompt.map { it.toInternal() }) + CountTokensRequest.forGenAI(constructRequest(*prompt)) private fun GenerateContentResponse.validate() = apply { if (candidates.isEmpty() && promptFeedback == null) {