Make CountTokenRequest contain the full request (google-gemini#152)
Co-authored-by: David Motsonashvili <[email protected]>
2 people authored and PatilShreyas committed Sep 21, 2024
1 parent 8f3f36c commit bfbcd55
Showing 5 changed files with 315 additions and 73 deletions.
@@ -20,23 +20,46 @@ import dev.shreyaspatil.ai.client.generativeai.common.client.Tool
import dev.shreyaspatil.ai.client.generativeai.common.client.ToolConfig
import dev.shreyaspatil.ai.client.generativeai.common.shared.Content
import dev.shreyaspatil.ai.client.generativeai.common.shared.SafetySetting
+import dev.shreyaspatil.ai.client.generativeai.common.util.fullModelName
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
-import kotlinx.serialization.Transient

sealed interface Request

@Serializable
data class GenerateContentRequest(
-  @Transient val model: String? = null,
-  val contents: List<Content>,
-  @SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
-  @SerialName("generation_config") val generationConfig: GenerationConfig? = null,
-  val tools: List<Tool>? = null,
-  @SerialName("tool_config") var toolConfig: ToolConfig? = null,
-  @SerialName("system_instruction") val systemInstruction: Content? = null,
+  val model: String? = null,
+  val contents: List<Content>,
+  @SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
+  @SerialName("generation_config") val generationConfig: GenerationConfig? = null,
+  val tools: List<Tool>? = null,
+  @SerialName("tool_config") var toolConfig: ToolConfig? = null,
+  @SerialName("system_instruction") val systemInstruction: Content? = null,
) : Request

@Serializable
-data class CountTokensRequest(@Transient val model: String? = null, val contents: List<Content>) :
-  Request
+data class CountTokensRequest(
+  val generateContentRequest: GenerateContentRequest? = null,
+  val model: String? = null,
+  val contents: List<Content>? = null,
+  val tools: List<Tool>? = null,
+  @SerialName("system_instruction") val systemInstruction: Content? = null,
+) : Request {
+  companion object {
+    fun forGenAI(generateContentRequest: GenerateContentRequest) =
+      CountTokensRequest(
+        generateContentRequest =
+          generateContentRequest.model?.let {
+            generateContentRequest.copy(model = fullModelName(it))
+          } ?: generateContentRequest
+      )
+
+    fun forVertexAI(generateContentRequest: GenerateContentRequest) =
+      CountTokensRequest(
+        model = generateContentRequest.model?.let { fullModelName(it) },
+        contents = generateContentRequest.contents,
+        tools = generateContentRequest.tools,
+        systemInstruction = generateContentRequest.systemInstruction,
+      )
+  }
+}
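Note (not part of the diff): a minimal usage sketch of the two new factories. It assumes the request classes above live in the dev.shreyaspatil.ai.client.generativeai.common package implied by this file's imports, with Content and TextPart borrowed from the test helpers further down.

import dev.shreyaspatil.ai.client.generativeai.common.CountTokensRequest
import dev.shreyaspatil.ai.client.generativeai.common.GenerateContentRequest
import dev.shreyaspatil.ai.client.generativeai.common.shared.Content
import dev.shreyaspatil.ai.client.generativeai.common.shared.TextPart

fun main() {
  val request =
    GenerateContentRequest(
      model = "gemini-pro-1.0",
      contents = listOf(Content(parts = listOf(TextPart("How many tokens is this?")))),
    )

  // Google AI backend: nest the whole GenerateContentRequest, with the model
  // name normalized to a "models/" path by fullModelName.
  val genAiCount = CountTokensRequest.forGenAI(request)
  println(genAiCount.generateContentRequest?.model) // models/gemini-pro-1.0

  // Vertex AI backend: lift model, contents, tools and systemInstruction to the
  // top level instead of nesting the request.
  val vertexCount = CountTokensRequest.forVertexAI(request)
  println(vertexCount.model) // models/gemini-pro-1.0
}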
@@ -19,6 +19,7 @@ package com.google.ai.client.generativeai.common
import android.util.Log
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.util.decodeToFlow
+import com.google.ai.client.generativeai.common.util.fullModelName
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
@@ -213,13 +214,6 @@ interface HeaderProvider {
  suspend fun generateHeaders(): Map<String, String>
}

-/**
- * Ensures the model name provided has a `models/` prefix
- *
- * Models must be prepended with the `models/` prefix when communicating with the backend.
- */
-private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"
-
private suspend fun validateResponse(response: HttpResponse) {
  if (response.status == HttpStatusCode.OK) return
  val text = response.bodyAsText()
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.client.generativeai.common.util
+
+/**
+ * Ensures the model name provided has a `models/` prefix
+ *
+ * Models must be prepended with the `models/` prefix when communicating with the backend.
+ */
+fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"
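Note (not part of the diff): the helper leaves any name that already contains a slash untouched and prefixes everything else with models/. A small self-contained check of that behaviour:

// Copy of the helper above, repeated here so the snippet runs on its own.
fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"

fun main() {
  println(fullModelName("gemini-pro-1.0"))        // models/gemini-pro-1.0
  println(fullModelName("models/gemini-pro-1.0")) // unchanged
  println(fullModelName("tunedModels/my-model"))  // unchanged
}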
@@ -25,7 +25,6 @@ import com.google.ai.client.generativeai.common.util.createResponses
import com.google.ai.client.generativeai.common.util.doBlocking
import com.google.ai.client.generativeai.common.util.prepareStreamingResponse
import io.kotest.assertions.json.shouldContainJsonKey
-import io.kotest.assertions.json.shouldNotContainJsonKey
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
@@ -135,60 +134,6 @@ internal class RequestFormatTests {
    mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
  }

-  @Test
-  fun `generateContentRequest doesn't include the model name`() = doBlocking {
-    val channel = ByteChannel(autoFlush = true)
-    val mockEngine = MockEngine {
-      respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
-    }
-    prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
-    val controller =
-      APIController(
-        "super_cool_test_key",
-        "gemini-pro-1.0",
-        RequestOptions(),
-        mockEngine,
-        TEST_CLIENT_ID,
-        null,
-      )
-
-    withTimeout(5.seconds) {
-      controller.generateContentStream(textGenerateContentRequest("cats")).collect {
-        it.candidates?.isEmpty() shouldBe false
-        channel.close()
-      }
-    }
-
-    val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
-    requestBodyAsText shouldContainJsonKey "contents"
-    requestBodyAsText shouldNotContainJsonKey "model"
-  }
-
-  @Test
-  fun `countTokenRequest doesn't include the model name`() = doBlocking {
-    val response =
-      JSON.encodeToString(CountTokensResponse(totalTokens = 10, totalBillableCharacters = 10))
-    val mockEngine = MockEngine {
-      respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
-    }
-
-    val controller =
-      APIController(
-        "super_cool_test_key",
-        "gemini-pro-1.0",
-        RequestOptions(),
-        mockEngine,
-        TEST_CLIENT_ID,
-        null,
-      )
-
-    withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
-
-    val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text
-    requestBodyAsText shouldContainJsonKey "contents"
-    requestBodyAsText shouldNotContainJsonKey "model"
-  }
-
  @Test
  fun `client id header is set correctly in the request`() = doBlocking {
    val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
@@ -367,4 +312,4 @@ fun textGenerateContentRequest(prompt: String) =
  )

fun textCountTokenRequest(prompt: String) =
-  CountTokensRequest(model = "unused", contents = listOf(Content(parts = listOf(TextPart(prompt)))))
+  CountTokensRequest(generateContentRequest = textGenerateContentRequest(prompt))
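Note (not part of the diff): with this helper change, a countTokens test request serializes the whole nested GenerateContentRequest rather than just model and contents. A rough way to inspect the resulting body, reusing the JSON instance and textCountTokenRequest helper already used in this test file:

import kotlinx.serialization.encodeToString

fun main() {
  // Prints the serialized countTokens body for a simple prompt; the nested
  // generateContentRequest object is now part of the payload.
  println(JSON.encodeToString(textCountTokenRequest("cats")))
}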