Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for header provider #110

Merged
merged 7 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.ai.client.generativeai.common

import android.util.Log
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.util.decodeToFlow
import io.ktor.client.HttpClient
Expand All @@ -36,12 +37,15 @@ import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlin.time.Duration
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.json.Json

val JSON = Json {
Expand All @@ -59,22 +63,25 @@ val JSON = Json {
* @property key The API key used for authentication.
* @property model The model to use for generation.
* @property apiClient The value to pass in the `x-goog-api-client` header.
* @property headerProvider A provider that generates extra headers to include in all HTTP requests.
*/
class APIController
internal constructor(
private val key: String,
model: String,
private val requestOptions: RequestOptions,
httpEngine: HttpClientEngine,
private val apiClient: String
private val apiClient: String,
private val headerProvider: HeaderProvider?
rlazo marked this conversation as resolved.
Show resolved Hide resolved
) {

constructor(
key: String,
model: String,
requestOptions: RequestOptions,
apiClient: String
) : this(key, model, requestOptions, OkHttp.create(), apiClient)
apiClient: String,
headerProvider: HeaderProvider? = null
) : this(key, model, requestOptions, OkHttp.create(), apiClient, headerProvider)

private val model = fullModelName(model)

Expand All @@ -92,6 +99,7 @@ internal constructor(
client
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
applyHeaderProvider()
}
.also { validateResponse(it) }
.body<GenerateContentResponse>()
Expand All @@ -115,6 +123,7 @@ internal constructor(
client
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") {
applyCommonConfiguration(request)
applyHeaderProvider()
}
.also { validateResponse(it) }
.body()
Expand All @@ -131,6 +140,77 @@ internal constructor(
header("x-goog-api-key", key)
header("x-goog-api-client", apiClient)
}

private suspend fun HttpRequestBuilder.applyHeaderProvider() {
if (headerProvider != null) {
try {
withTimeout(headerProvider.timeout) {
for ((tag, value) in headerProvider.generateHeaders()) {
header(tag, value)
}
}
} catch (e: TimeoutCancellationException) {
Log.w(TAG, "HeaderProvided timed out without generating headers, ignoring")
}
}
}

/**
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response
* objects of type [R]. The response is expected to be a stream of JSON objects that are parsed in
* real-time as they are received from the server.
*
* This function is intended for internal use within the client that handles streaming responses.
*
* Example usage:
* ```
* val client: HttpClient = HttpClient(CIO)
* val request: Request = GenerateContentRequest(...)
* val url: String = "http://example.com/stream"
*
* val responses: GenerateContentResponse = client.postStream(url) {
* setBody(request)
* contentType(ContentType.Application.Json)
* }
* responses.collect {
* println("Got a response: $it")
* }
* ```
*
* @param R The type of the response object.
* @param url The URL to which the POST request will be made.
* @param config An optional [HttpRequestBuilder] callback for request configuration.
* @return A [Flow] of response objects of type [R].
*/
private inline fun <reified R : Response> HttpClient.postStream(
url: String,
crossinline config: HttpRequestBuilder.() -> Unit = {},
): Flow<R> = channelFlow {
launch(CoroutineName("postStream")) {
preparePost(url) {
applyHeaderProvider()
config()
}
.execute {
validateResponse(it)

val channel = it.bodyAsChannel()
val flow = JSON.decodeToFlow<R>(channel)

flow.collect { send(it) }
}
}
}

companion object {
private val TAG = APIController::class.java.simpleName
}
}

interface HeaderProvider {
daymxn marked this conversation as resolved.
Show resolved Hide resolved
val timeout: Duration

suspend fun generateHeaders(): Map<String, String>
}

/**
Expand All @@ -140,50 +220,6 @@ internal constructor(
*/
private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"

/**
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects
* of type [R]. The response is expected to be a stream of JSON objects that are parsed in real-time
* as they are received from the server.
*
* This function is intended for internal use within the client that handles streaming responses.
*
* Example usage:
* ```
* val client: HttpClient = HttpClient(CIO)
* val request: Request = GenerateContentRequest(...)
* val url: String = "http://example.com/stream"
*
* val responses: GenerateContentResponse = client.postStream(url) {
* setBody(request)
* contentType(ContentType.Application.Json)
* }
* responses.collect {
* println("Got a response: $it")
* }
* ```
*
* @param R The type of the response object.
* @param url The URL to which the POST request will be made.
* @param config An optional [HttpRequestBuilder] callback for request configuration.
* @return A [Flow] of response objects of type [R].
*/
private inline fun <reified R : Response> HttpClient.postStream(
url: String,
crossinline config: HttpRequestBuilder.() -> Unit = {}
): Flow<R> = channelFlow {
launch(CoroutineName("postStream")) {
preparePost(url) { config() }
.execute {
validateResponse(it)

val channel = it.bodyAsChannel()
val flow = JSON.decodeToFlow<R>(channel)

flow.collect { send(it) }
}
}
}

private suspend fun validateResponse(response: HttpResponse) {
if (response.status == HttpStatusCode.OK) return
val text = response.bodyAsText()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.delay
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.encodeToString
import org.junit.Test
Expand Down Expand Up @@ -91,7 +94,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
"genai-android/${BuildConfig.VERSION_NAME}"
"genai-android/${BuildConfig.VERSION_NAME}",
null
)

withTimeout(5.seconds) {
Expand All @@ -117,7 +121,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(endpoint = "https://my.custom.endpoint"),
mockEngine,
TEST_CLIENT_ID
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) {
Expand All @@ -143,7 +148,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) {
Expand Down Expand Up @@ -172,7 +178,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
Expand All @@ -195,7 +202,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
Expand All @@ -217,7 +225,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) {
Expand All @@ -240,6 +249,71 @@ internal class RequestFormatTests {

requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode"
}

@Test
fun `headers from HeaderProvider are added to the request`() = doBlocking {
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
val mockEngine = MockEngine {
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}

val testHeaderProvider =
object : HeaderProvider {
override val timeout: Duration
get() = 5.seconds

override suspend fun generateHeaders(): Map<String, String> =
mapOf("header1" to "value1", "header2" to "value2")
}

val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
testHeaderProvider
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }

mockEngine.requestHistory.first().headers["header1"] shouldBe "value1"
mockEngine.requestHistory.first().headers["header2"] shouldBe "value2"
}

@Test
fun `headers from HeaderProvider are ignored if timeout`() = doBlocking {
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
val mockEngine = MockEngine {
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}

val testHeaderProvider =
object : HeaderProvider {
override val timeout: Duration
get() = 5.milliseconds

override suspend fun generateHeaders(): Map<String, String> {
delay(10.milliseconds)
return mapOf("header1" to "value1")
}
}

val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
testHeaderProvider
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }

mockEngine.requestHistory.first().headers.contains("header1") shouldBe false
}
}

@RunWith(Parameterized::class)
Expand All @@ -253,7 +327,14 @@ internal class ModelNamingTests(private val modelName: String, private val actua
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine, TEST_CLIENT_ID)
APIController(
"super_cool_test_key",
modelName,
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,14 @@ internal fun commonTest(
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val apiController =
APIController("super_cool_test_key", "gemini-pro", requestOptions, mockEngine, TEST_CLIENT_ID)
APIController(
"super_cool_test_key",
"gemini-pro",
requestOptions,
mockEngine,
TEST_CLIENT_ID,
null
)
CommonTestScope(channel, apiController).block()
}

Expand Down
Loading