Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make CountTokenRequest contain the full request #152

Merged
merged 15 commits into from
Jun 10, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.google.ai.client.generativeai.common
import android.util.Log
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.util.decodeToFlow
import com.google.ai.client.generativeai.common.util.fullModelName
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
Expand Down Expand Up @@ -213,13 +214,6 @@ interface HeaderProvider {
suspend fun generateHeaders(): Map<String, String>
}

/**
 * Resolves [name] to a fully qualified backend model name.
 *
 * A bare name (one containing no `/`) is prefixed with the `models/` resource path; a name
 * that already contains a separator is returned unchanged.
 */
private fun fullModelName(name: String): String =
  if (name.contains("/")) name else "models/$name"

private suspend fun validateResponse(response: HttpResponse) {
if (response.status == HttpStatusCode.OK) return
val text = response.bodyAsText()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ import com.google.ai.client.generativeai.common.client.Tool
import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.util.fullModelName
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient

sealed interface Request

@Serializable
data class GenerateContentRequest(
@Transient val model: String? = null,
val model: String? = null,
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
Expand All @@ -39,5 +39,28 @@ data class GenerateContentRequest(
) : Request

@Serializable
data class CountTokensRequest(@Transient val model: String? = null, val contents: List<Content>) :
Request
data class CountTokensRequest(
  val generateContentRequest: GenerateContentRequest? = null,
  val model: String? = null,
  val contents: List<Content>? = null,
  val tools: List<Tool>? = null,
  @SerialName("system_instruction") val systemInstruction: Content? = null,
) : Request {
  companion object {
    /**
     * Builds a request in the shape the Google AI backend expects: the full
     * [GenerateContentRequest] nested under `generateContentRequest`, with its model name
     * (when present) normalized through [fullModelName].
     */
    fun forGenAI(generateContentRequest: GenerateContentRequest): CountTokensRequest {
      val modelName = generateContentRequest.model
      val normalizedRequest =
        if (modelName == null) generateContentRequest
        else generateContentRequest.copy(model = fullModelName(modelName))
      return CountTokensRequest(generateContentRequest = normalizedRequest)
    }

    /**
     * Builds a request in the shape the Vertex AI backend expects: the model, contents,
     * tools, and system instruction copied flat onto the top-level request, with the model
     * name (when present) normalized through [fullModelName].
     */
    fun forVertexAI(generateContentRequest: GenerateContentRequest): CountTokensRequest =
      CountTokensRequest(
        model = generateContentRequest.model?.let(::fullModelName),
        contents = generateContentRequest.contents,
        tools = generateContentRequest.tools,
        systemInstruction = generateContentRequest.systemInstruction,
      )
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.common.util

/**
 * Normalizes [name] into the fully qualified form the backend requires.
 *
 * Bare model names (those without a `/`) are qualified with the `models/` resource prefix;
 * names that already contain a path separator are passed through as-is.
 */
fun fullModelName(name: String): String =
  if (name.contains("/")) name else "models/$name"
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import com.google.ai.client.generativeai.common.util.createResponses
import com.google.ai.client.generativeai.common.util.doBlocking
import com.google.ai.client.generativeai.common.util.prepareStreamingResponse
import io.kotest.assertions.json.shouldContainJsonKey
import io.kotest.assertions.json.shouldNotContainJsonKey
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
Expand Down Expand Up @@ -135,60 +134,6 @@ internal class RequestFormatTests {
mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
}

// Verifies that the serialized generateContent request body carries "contents" but NOT a
// top-level "model" key (the model name travels in the request URL, not the JSON payload).
@Test
fun `generateContentRequest doesn't include the model name`() = doBlocking {
// Manually controlled channel lets the test close the stream after the first response.
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
// Pre-load one canned streaming response so collect {} below receives a candidate.
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null,
)

// Collect a single streamed chunk, then close the channel to end the stream.
withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
it.candidates?.isEmpty() shouldBe false
channel.close()
}
}

// Inspect the raw JSON the controller actually sent to the (mock) backend.
val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
requestBodyAsText shouldContainJsonKey "contents"
requestBodyAsText shouldNotContainJsonKey "model"
}

// Verifies that the serialized countTokens request body carries "contents" but NOT a
// top-level "model" key (the model name travels in the request URL, not the JSON payload).
@Test
fun `countTokenRequest doesn't include the model name`() = doBlocking {
// Single non-streaming canned response; countTokens only needs one JSON reply.
val response =
JSON.encodeToString(CountTokensResponse(totalTokens = 10, totalBillableCharacters = 10))
val mockEngine = MockEngine {
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}

val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null,
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }

// Inspect the raw JSON the controller actually sent to the (mock) backend.
val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
requestBodyAsText shouldContainJsonKey "contents"
requestBodyAsText shouldNotContainJsonKey "model"
}

@Test
fun `client id header is set correctly in the request`() = doBlocking {
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
Expand Down Expand Up @@ -367,4 +312,4 @@ fun textGenerateContentRequest(prompt: String) =
)

fun textCountTokenRequest(prompt: String) =
CountTokensRequest(model = "unused", contents = listOf(Content(parts = listOf(TextPart(prompt)))))
CountTokensRequest(generateContentRequest = textGenerateContentRequest(prompt))
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import android.graphics.Bitmap
import com.google.ai.client.generativeai.common.APIController
import com.google.ai.client.generativeai.common.CountTokensRequest
import com.google.ai.client.generativeai.common.GenerateContentRequest
import com.google.ai.client.generativeai.common.util.fullModelName
import com.google.ai.client.generativeai.internal.util.toInternal
import com.google.ai.client.generativeai.internal.util.toPublic
import com.google.ai.client.generativeai.type.Content
Expand Down Expand Up @@ -85,7 +86,7 @@ internal constructor(
toolConfig: ToolConfig? = null,
systemInstruction: Content? = null,
) : this(
modelName,
fullModelName(modelName),
davidmotson marked this conversation as resolved.
Show resolved Hide resolved
apiKey,
generationConfig,
safetySettings,
Expand Down Expand Up @@ -240,7 +241,7 @@ internal constructor(
)

private fun constructCountTokensRequest(vararg prompt: Content) =
CountTokensRequest(modelName, prompt.map { it.toInternal() })
CountTokensRequest.forGenAI(constructRequest(*prompt))

private fun GenerateContentResponse.validate() = apply {
if (candidates.isEmpty() && promptFeedback == null) {
Expand Down
Loading