Skip to content

Commit

Permalink
Add clientId parameter in common (#107)
Browse files Browse the repository at this point in the history
This parameter is used to fill the `x-goog-api-client` header
  • Loading branch information
rlazo authored Apr 3, 2024
1 parent 91d69a9 commit 6bc885f
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,23 @@ val JSON = Json {
* Exposed primarily for DI in tests.
* @property key The API key used for authentication.
* @property model The model to use for generation.
* @property clientId The value to pass in the `x-goog-api-client` header.
*/
class APIController
internal constructor(
private val key: String,
model: String,
private val requestOptions: RequestOptions,
httpEngine: HttpClientEngine
httpEngine: HttpClientEngine,
private val clientId: String
) {

constructor(
key: String,
model: String,
requestOptions: RequestOptions
) : this(key, model, requestOptions, OkHttp.create())
requestOptions: RequestOptions,
clientId: String
) : this(key, model, requestOptions, OkHttp.create(), clientId)

private val model = fullModelName(model)

Expand Down Expand Up @@ -127,7 +130,7 @@ internal constructor(
}
contentType(ContentType.Application.Json)
header("x-goog-api-key", key)
header("x-goog-api-client", "genai-android/${BuildConfig.VERSION_NAME}")
header("x-goog-api-client", clientId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

private val TEST_CLIENT_ID = "genai-android/test"

internal class APIControllerTests {
private val testTimeout = 5.seconds

Expand Down Expand Up @@ -84,7 +86,13 @@ internal class RequestFormatTests {
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
"genai-android/${BuildConfig.VERSION_NAME}"
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
Expand All @@ -108,7 +116,8 @@ internal class RequestFormatTests {
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(endpoint = "https://my.custom.endpoint"),
mockEngine
mockEngine,
TEST_CLIENT_ID
)

withTimeout(5.seconds) {
Expand All @@ -129,7 +138,13 @@ internal class RequestFormatTests {
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
Expand All @@ -151,7 +166,13 @@ internal class RequestFormatTests {
}

val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }

Expand All @@ -160,6 +181,27 @@ internal class RequestFormatTests {
requestBodyAsText shouldNotContainJsonKey "model"
}

@Test
fun `client id header is set correctly in the request`() = doBlocking {
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
val mockEngine = MockEngine {
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}

val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }

mockEngine.requestHistory.first().headers["x-goog-api-client"] shouldBe TEST_CLIENT_ID
}

@Test
fun `ToolConfig serialization contains correct keys`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
Expand All @@ -169,7 +211,13 @@ internal class RequestFormatTests {
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }

val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
)

withTimeout(5.seconds) {
controller
Expand Down Expand Up @@ -203,7 +251,8 @@ internal class ModelNamingTests(private val modelName: String, private val actua
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller = APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine)
val controller =
APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine, TEST_CLIENT_ID)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import java.io.File
import kotlinx.coroutines.launch
import kotlinx.serialization.encodeToString

private val TEST_CLIENT_ID = "genai-android/test"

internal fun prepareStreamingResponse(response: List<GenerateContentResponse>): List<ByteArray> =
response.map { "data: ${JSON.encodeToString(it)}$SSE_SEPARATOR".toByteArray() }

Expand Down Expand Up @@ -104,7 +106,8 @@ internal fun commonTest(
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val apiController = APIController("super_cool_test_key", "gemini-pro", requestOptions, mockEngine)
val apiController =
APIController("super_cool_test_key", "gemini-pro", requestOptions, mockEngine, TEST_CLIENT_ID)
CommonTestScope(channel, apiController).block()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ internal constructor(
tools,
toolConfig,
requestOptions,
APIController(apiKey, modelName, requestOptions.toInternal()),
APIController(
apiKey,
modelName,
requestOptions.toInternal(),
"genai-android/${BuildConfig.VERSION_NAME}"
),
)

/**
Expand Down

0 comments on commit 6bc885f

Please sign in to comment.