Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add clientId parameter in common #107

Merged
merged 5 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,23 @@ val JSON = Json {
* Exposed primarily for DI in tests.
* @property key The API key used for authentication.
* @property model The model to use for generation.
* @property clientId The value to pass in the `x-goog-api-client` header.
*/
class APIController
internal constructor(
private val key: String,
model: String,
private val requestOptions: RequestOptions,
httpEngine: HttpClientEngine
httpEngine: HttpClientEngine,
private val clientId: String
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This field is named apiClient in Web - should we name it the same thing for consistency? I understand it's an internal field, so consistency shouldn't be a top priority.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks for the heads up @thatfiredev See #108

) {

constructor(
key: String,
model: String,
requestOptions: RequestOptions
) : this(key, model, requestOptions, OkHttp.create())
requestOptions: RequestOptions,
clientId: String
) : this(key, model, requestOptions, OkHttp.create(), clientId)
daymxn marked this conversation as resolved.
Show resolved Hide resolved

private val model = fullModelName(model)

Expand Down Expand Up @@ -127,7 +130,7 @@ internal constructor(
}
contentType(ContentType.Application.Json)
header("x-goog-api-key", key)
header("x-goog-api-client", "genai-android/${BuildConfig.VERSION_NAME}")
header("x-goog-api-client", clientId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

private val TEST_CLIENT_ID = "genai-android/test"

internal class APIControllerTests {
private val testTimeout = 5.seconds

Expand Down Expand Up @@ -84,7 +86,13 @@ internal class RequestFormatTests {
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
"genai-android/${BuildConfig.VERSION_NAME}"
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
Expand All @@ -108,7 +116,8 @@ internal class RequestFormatTests {
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(endpoint = "https://my.custom.endpoint"),
mockEngine
mockEngine,
TEST_CLIENT_ID
)

withTimeout(5.seconds) {
Expand All @@ -129,7 +138,13 @@ internal class RequestFormatTests {
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
Expand All @@ -151,7 +166,13 @@ internal class RequestFormatTests {
}

val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }

Expand All @@ -160,6 +181,27 @@ internal class RequestFormatTests {
requestBodyAsText shouldNotContainJsonKey "model"
}

@Test
fun `client id header is set correctly in the request`() = doBlocking {
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
val mockEngine = MockEngine {
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}

val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }

mockEngine.requestHistory.first().headers["x-goog-api-client"] shouldBe TEST_CLIENT_ID
}

@Test
fun `ToolConfig serialization contains correct keys`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
Expand Down Expand Up @@ -203,7 +245,8 @@ internal class ModelNamingTests(private val modelName: String, private val actua
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller = APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine)
val controller =
APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine, TEST_CLIENT_ID)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import java.io.File
import kotlinx.coroutines.launch
import kotlinx.serialization.encodeToString

private val TEST_CLIENT_ID = "genai-android/test"

internal fun prepareStreamingResponse(response: List<GenerateContentResponse>): List<ByteArray> =
response.map { "data: ${JSON.encodeToString(it)}$SSE_SEPARATOR".toByteArray() }

Expand Down Expand Up @@ -104,7 +106,8 @@ internal fun commonTest(
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val apiController = APIController("super_cool_test_key", "gemini-pro", requestOptions, mockEngine)
val apiController =
APIController("super_cool_test_key", "gemini-pro", requestOptions, mockEngine, TEST_CLIENT_ID)
CommonTestScope(channel, apiController).block()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ internal constructor(
tools,
toolConfig,
requestOptions,
APIController(apiKey, modelName, requestOptions.toInternal()),
APIController(
apiKey,
modelName,
requestOptions.toInternal(),
"genai-android/${BuildConfig.VERSION_NAME}"
),
)

/**
Expand Down
Loading