diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt index e9cbae41..257bdfca 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt @@ -59,20 +59,23 @@ val JSON = Json { * Exposed primarily for DI in tests. * @property key The API key used for authentication. * @property model The model to use for generation. + * @property clientId The value to pass in the `x-goog-api-client` header. */ class APIController internal constructor( private val key: String, model: String, private val requestOptions: RequestOptions, - httpEngine: HttpClientEngine + httpEngine: HttpClientEngine, + private val clientId: String ) { constructor( key: String, model: String, - requestOptions: RequestOptions - ) : this(key, model, requestOptions, OkHttp.create()) + requestOptions: RequestOptions, + clientId: String + ) : this(key, model, requestOptions, OkHttp.create(), clientId) private val model = fullModelName(model) @@ -127,7 +130,7 @@ internal constructor( } contentType(ContentType.Application.Json) header("x-goog-api-key", key) - header("x-goog-api-client", "genai-android/${BuildConfig.VERSION_NAME}") + header("x-goog-api-client", clientId) } } diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt index 5564f53a..959f10dd 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt @@ -45,6 +45,8 @@ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized +private val TEST_CLIENT_ID = "genai-android/test" + internal class APIControllerTests { private val testTimeout = 5.seconds @@ -84,7 +86,13 @@ internal class RequestFormatTests { } prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } val controller = - APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine) + APIController( + "super_cool_test_key", + "gemini-pro-1.0", + RequestOptions(), + mockEngine, + "genai-android/${BuildConfig.VERSION_NAME}" + ) withTimeout(5.seconds) { controller.generateContentStream(textGenerateContentRequest("cats")).collect { @@ -108,7 +116,8 @@ internal class RequestFormatTests { "super_cool_test_key", "gemini-pro-1.0", RequestOptions(endpoint = "https://my.custom.endpoint"), - mockEngine + mockEngine, + TEST_CLIENT_ID ) withTimeout(5.seconds) { @@ -129,7 +138,13 @@ internal class RequestFormatTests { } prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } val controller = - APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine) + APIController( + "super_cool_test_key", + "gemini-pro-1.0", + RequestOptions(), + mockEngine, + TEST_CLIENT_ID + ) withTimeout(5.seconds) { controller.generateContentStream(textGenerateContentRequest("cats")).collect { @@ -152,7 +167,13 @@ internal class RequestFormatTests { } val controller = - APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine) + APIController( + "super_cool_test_key", + "gemini-pro-1.0", + RequestOptions(), + mockEngine, + TEST_CLIENT_ID + ) withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) } @@ -161,6 +182,27 @@ internal class RequestFormatTests { requestBodyAsText shouldNotContainJsonKey "model" } + @Test + fun `client id header is set correctly in the request`() = doBlocking { + val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10)) + val mockEngine = MockEngine { + respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + + val controller = + APIController( + "super_cool_test_key", + "gemini-pro-1.0", + RequestOptions(), + mockEngine, + TEST_CLIENT_ID + ) + + withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) } + + mockEngine.requestHistory.first().headers["x-goog-api-client"] shouldBe TEST_CLIENT_ID + } + @Test fun `ToolConfig serialization contains correct keys`() = doBlocking { val channel = ByteChannel(autoFlush = true) @@ -170,7 +212,13 @@ internal class RequestFormatTests { prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } val controller = - APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine) + APIController( + "super_cool_test_key", + "gemini-pro-1.0", + RequestOptions(), + mockEngine, + TEST_CLIENT_ID + ) withTimeout(5.seconds) { controller @@ -204,7 +252,8 @@ internal class ModelNamingTests(private val modelName: String, private val actua respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) } prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } - val controller = APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine) + val controller = + APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine, TEST_CLIENT_ID) withTimeout(5.seconds) { controller.generateContentStream(textGenerateContentRequest("cats")).collect { diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt index d753a9bf..0acb8d75 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt @@ -38,6 +38,8 @@ import java.io.File import kotlinx.coroutines.launch import kotlinx.serialization.encodeToString +private val TEST_CLIENT_ID = "genai-android/test" + internal fun prepareStreamingResponse(response: List<GenerateContentResponse>): List<ByteArray> = response.map { "data: ${JSON.encodeToString(it)}$SSE_SEPARATOR".toByteArray() } @@ -104,7 +106,8 @@ internal fun commonTest( val mockEngine = MockEngine { respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) } - val apiController = APIController("super_cool_test_key", "gemini-pro", requestOptions, mockEngine) + val apiController = + APIController("super_cool_test_key", "gemini-pro", requestOptions, mockEngine, TEST_CLIENT_ID) CommonTestScope(channel, apiController).block() } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index a2f285f1..d1639bb2 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -90,7 +90,12 @@ internal constructor( tools, toolConfig, requestOptions, - APIController(apiKey, modelName, requestOptions.toInternal()), + APIController( + apiKey, + modelName, + requestOptions.toInternal(), + "genai-android/${BuildConfig.VERSION_NAME}" + ), ) /**