Skip to content

Commit

Permalink
Make CountTokenRequest contain the full request (#152)
Browse files Browse the repository at this point in the history
Co-authored-by: David Motsonashvili <[email protected]>
  • Loading branch information
davidmotson and David Motsonashvili authored Jun 10, 2024
1 parent 28016d1 commit 9c4aaf0
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.google.ai.client.generativeai.common
import android.util.Log
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.util.decodeToFlow
import com.google.ai.client.generativeai.common.util.fullModelName
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
Expand Down Expand Up @@ -213,13 +214,6 @@ interface HeaderProvider {
suspend fun generateHeaders(): Map<String, String>
}

/**
 * Returns the model name in the form that is sent to the backend.
 *
 * A name that already contains a `/` is returned unchanged; any other name is returned with
 * `models/` prepended.
 */
private fun fullModelName(name: String): String = if ("/" in name) name else "models/$name"

private suspend fun validateResponse(response: HttpResponse) {
if (response.status == HttpStatusCode.OK) return
val text = response.bodyAsText()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ import com.google.ai.client.generativeai.common.client.Tool
import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.util.fullModelName
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient

sealed interface Request

@Serializable
data class GenerateContentRequest(
@Transient val model: String? = null,
val model: String? = null,
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
Expand All @@ -39,5 +39,28 @@ data class GenerateContentRequest(
) : Request

@Serializable
data class CountTokensRequest(@Transient val model: String? = null, val contents: List<Content>) :
Request
/**
 * Request body for a token-count call.
 *
 * Every field is optional. [forGenAI] fills in only [generateContentRequest], while [forVertexAI]
 * fills in only the flattened [model], [contents], [tools] and [systemInstruction] fields.
 */
data class CountTokensRequest(
  val generateContentRequest: GenerateContentRequest? = null,
  val model: String? = null,
  val contents: List<Content>? = null,
  val tools: List<Tool>? = null,
  @SerialName("system_instruction") val systemInstruction: Content? = null,
) : Request {
  companion object {
    /**
     * Wraps the whole [generateContentRequest] inside a [CountTokensRequest].
     *
     * When the wrapped request carries a model name, it is replaced by a copy whose name has been
     * passed through [fullModelName]; a request without a model name is wrapped as-is.
     */
    fun forGenAI(generateContentRequest: GenerateContentRequest): CountTokensRequest {
      val rawModel = generateContentRequest.model
      val wrapped =
        if (rawModel == null) generateContentRequest
        else generateContentRequest.copy(model = fullModelName(rawModel))
      return CountTokensRequest(generateContentRequest = wrapped)
    }

    /**
     * Copies the model name (passed through [fullModelName] when present), contents, tools and
     * system instruction of [generateContentRequest] into the top-level fields of a
     * [CountTokensRequest]. The nested `generateContentRequest` field is left `null`.
     */
    fun forVertexAI(generateContentRequest: GenerateContentRequest): CountTokensRequest {
      val rawModel = generateContentRequest.model
      val qualifiedModel = if (rawModel == null) null else fullModelName(rawModel)
      return CountTokensRequest(
        model = qualifiedModel,
        contents = generateContentRequest.contents,
        tools = generateContentRequest.tools,
        systemInstruction = generateContentRequest.systemInstruction,
      )
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.common.util

/**
 * Qualifies a bare model name with the `models/` prefix used when talking to the backend.
 *
 * Names that already contain a `/` are treated as qualified and returned unchanged; every other
 * name comes back as `models/<name>`.
 */
fun fullModelName(name: String): String {
  val alreadyQualified = '/' in name
  return if (alreadyQualified) name else "models/$name"
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import com.google.ai.client.generativeai.common.util.createResponses
import com.google.ai.client.generativeai.common.util.doBlocking
import com.google.ai.client.generativeai.common.util.prepareStreamingResponse
import io.kotest.assertions.json.shouldContainJsonKey
import io.kotest.assertions.json.shouldNotContainJsonKey
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
Expand Down Expand Up @@ -135,60 +134,6 @@ internal class RequestFormatTests {
mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
}

/**
 * Runs a streaming generate-content call against a mock engine and checks that the JSON body of
 * the first recorded request has a `contents` key and no `model` key.
 */
@Test
fun `generateContentRequest doesn't include the model name`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
// The mock engine answers each request with the channel as an HTTP 200 JSON body.
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
// Write the canned response chunks into the channel before the request is issued.
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null,
)

// Fail the test rather than hang if the stream does not complete within five seconds.
withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
// NOTE(review): closing the channel presumably ends the response stream so that collect can
// finish — confirm.
channel.close()
}
}

// Inspect the body of the first request the mock engine recorded.
val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
requestBodyAsText shouldContainJsonKey "contents"
requestBodyAsText shouldNotContainJsonKey "model"
}

/**
 * Sends a count-tokens call through a mock engine and checks that the JSON body of the first
 * recorded request has a `contents` key and no `model` key.
 */
@Test
fun `countTokenRequest doesn't include the model name`() = doBlocking {
  // Canned reply returned by the mock engine for every request.
  val cannedBody =
    JSON.encodeToString(CountTokensResponse(totalTokens = 10, totalBillableCharacters = 10))
  val engine = MockEngine {
    respond(cannedBody, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
  }

  val apiController =
    APIController(
      "super_cool_test_key",
      "gemini-pro-1.0",
      RequestOptions(),
      engine,
      TEST_CLIENT_ID,
      null,
    )

  withTimeout(5.seconds) { apiController.countTokens(textCountTokenRequest("cats")) }

  val firstRequest = engine.requestHistory.first()
  val sentJson = (firstRequest.body as TextContent).text
  sentJson shouldContainJsonKey "contents"
  sentJson shouldNotContainJsonKey "model"
}

@Test
fun `client id header is set correctly in the request`() = doBlocking {
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
Expand Down Expand Up @@ -367,4 +312,4 @@ fun textGenerateContentRequest(prompt: String) =
)

/** Test helper that builds a [CountTokensRequest] from the text [prompt]. */
fun textCountTokenRequest(prompt: String) =
CountTokensRequest(model = "unused", contents = listOf(Content(parts = listOf(TextPart(prompt)))))
CountTokensRequest(generateContentRequest = textGenerateContentRequest(prompt))
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import android.graphics.Bitmap
import com.google.ai.client.generativeai.common.APIController
import com.google.ai.client.generativeai.common.CountTokensRequest
import com.google.ai.client.generativeai.common.GenerateContentRequest
import com.google.ai.client.generativeai.common.util.fullModelName
import com.google.ai.client.generativeai.internal.util.toInternal
import com.google.ai.client.generativeai.internal.util.toPublic
import com.google.ai.client.generativeai.type.Content
Expand Down Expand Up @@ -85,7 +86,7 @@ internal constructor(
toolConfig: ToolConfig? = null,
systemInstruction: Content? = null,
) : this(
modelName,
fullModelName(modelName),
apiKey,
generationConfig,
safetySettings,
Expand Down Expand Up @@ -240,7 +241,7 @@ internal constructor(
)

/** Builds the [CountTokensRequest] for the given [prompt] contents. */
private fun constructCountTokensRequest(vararg prompt: Content) =
CountTokensRequest(modelName, prompt.map { it.toInternal() })
CountTokensRequest.forGenAI(constructRequest(*prompt))

private fun GenerateContentResponse.validate() = apply {
if (candidates.isEmpty() && promptFeedback == null) {
Expand Down

0 comments on commit 9c4aaf0

Please sign in to comment.