Skip to content

Commit

Permalink
Add support for header provider (#110)
Browse files Browse the repository at this point in the history
Co-authored-by: Daymon <[email protected]>
  • Loading branch information
rlazo and daymxn authored Apr 4, 2024
1 parent 4179344 commit b65c155
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.ai.client.generativeai.common

import android.util.Log
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.util.decodeToFlow
import io.ktor.client.HttpClient
Expand All @@ -36,12 +37,15 @@ import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlin.time.Duration
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.json.Json

val JSON = Json {
Expand All @@ -59,22 +63,25 @@ val JSON = Json {
* @property key The API key used for authentication.
* @property model The model to use for generation.
* @property apiClient The value to pass in the `x-goog-api-client` header.
* @property headerProvider A provider that generates extra headers to include in all HTTP requests.
*/
class APIController
internal constructor(
private val key: String,
model: String,
private val requestOptions: RequestOptions,
httpEngine: HttpClientEngine,
private val apiClient: String
private val apiClient: String,
private val headerProvider: HeaderProvider?
) {

constructor(
key: String,
model: String,
requestOptions: RequestOptions,
apiClient: String
) : this(key, model, requestOptions, OkHttp.create(), apiClient)
apiClient: String,
headerProvider: HeaderProvider? = null
) : this(key, model, requestOptions, OkHttp.create(), apiClient, headerProvider)

private val model = fullModelName(model)

Expand All @@ -92,6 +99,7 @@ internal constructor(
client
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
applyHeaderProvider()
}
.also { validateResponse(it) }
.body<GenerateContentResponse>()
Expand All @@ -115,6 +123,7 @@ internal constructor(
client
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") {
applyCommonConfiguration(request)
applyHeaderProvider()
}
.also { validateResponse(it) }
.body()
Expand All @@ -131,6 +140,77 @@ internal constructor(
header("x-goog-api-key", key)
header("x-goog-api-client", apiClient)
}

private suspend fun HttpRequestBuilder.applyHeaderProvider() {
if (headerProvider != null) {
try {
withTimeout(headerProvider.timeout) {
for ((tag, value) in headerProvider.generateHeaders()) {
header(tag, value)
}
}
} catch (e: TimeoutCancellationException) {
Log.w(TAG, "HeaderProvided timed out without generating headers, ignoring")
}
}
}

/**
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response
* objects of type [R]. The response is expected to be a stream of JSON objects that are parsed in
* real-time as they are received from the server.
*
* This function is intended for internal use within the client that handles streaming responses.
*
* Example usage:
* ```
* val client: HttpClient = HttpClient(CIO)
* val request: Request = GenerateContentRequest(...)
* val url: String = "http://example.com/stream"
*
* val responses: GenerateContentResponse = client.postStream(url) {
* setBody(request)
* contentType(ContentType.Application.Json)
* }
* responses.collect {
* println("Got a response: $it")
* }
* ```
*
* @param R The type of the response object.
* @param url The URL to which the POST request will be made.
* @param config An optional [HttpRequestBuilder] callback for request configuration.
* @return A [Flow] of response objects of type [R].
*/
private inline fun <reified R : Response> HttpClient.postStream(
url: String,
crossinline config: HttpRequestBuilder.() -> Unit = {},
): Flow<R> = channelFlow {
launch(CoroutineName("postStream")) {
preparePost(url) {
applyHeaderProvider()
config()
}
.execute {
validateResponse(it)

val channel = it.bodyAsChannel()
val flow = JSON.decodeToFlow<R>(channel)

flow.collect { send(it) }
}
}
}

companion object {
private val TAG = APIController::class.java.simpleName
}
}

interface HeaderProvider {
val timeout: Duration

suspend fun generateHeaders(): Map<String, String>
}

/**
Expand All @@ -140,50 +220,6 @@ internal constructor(
*/
private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"

/**
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects
* of type [R]. The response is expected to be a stream of JSON objects that are parsed in real-time
* as they are received from the server.
*
* This function is intended for internal use within the client that handles streaming responses.
*
* Example usage:
* ```
* val client: HttpClient = HttpClient(CIO)
* val request: Request = GenerateContentRequest(...)
* val url: String = "http://example.com/stream"
*
* val responses: GenerateContentResponse = client.postStream(url) {
* setBody(request)
* contentType(ContentType.Application.Json)
* }
* responses.collect {
* println("Got a response: $it")
* }
* ```
*
* @param R The type of the response object.
* @param url The URL to which the POST request will be made.
* @param config An optional [HttpRequestBuilder] callback for request configuration.
* @return A [Flow] of response objects of type [R].
*/
private inline fun <reified R : Response> HttpClient.postStream(
url: String,
crossinline config: HttpRequestBuilder.() -> Unit = {}
): Flow<R> = channelFlow {
launch(CoroutineName("postStream")) {
preparePost(url) { config() }
.execute {
validateResponse(it)

val channel = it.bodyAsChannel()
val flow = JSON.decodeToFlow<R>(channel)

flow.collect { send(it) }
}
}
}

private suspend fun validateResponse(response: HttpResponse) {
if (response.status == HttpStatusCode.OK) return
val text = response.bodyAsText()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.delay
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.encodeToString
import org.junit.Test
Expand Down Expand Up @@ -91,7 +94,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
"genai-android/${BuildConfig.VERSION_NAME}"
"genai-android/${BuildConfig.VERSION_NAME}",
null
)

withTimeout(5.seconds) {
Expand All @@ -117,7 +121,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(endpoint = "https://my.custom.endpoint"),
mockEngine,
TEST_CLIENT_ID
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) {
Expand All @@ -143,7 +148,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) {
Expand Down Expand Up @@ -172,7 +178,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
Expand All @@ -195,7 +202,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
Expand All @@ -217,7 +225,8 @@ internal class RequestFormatTests {
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) {
Expand All @@ -240,6 +249,71 @@ internal class RequestFormatTests {

requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode"
}

@Test
fun `headers from HeaderProvider are added to the request`() = doBlocking {
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
val mockEngine = MockEngine {
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}

val testHeaderProvider =
object : HeaderProvider {
override val timeout: Duration
get() = 5.seconds

override suspend fun generateHeaders(): Map<String, String> =
mapOf("header1" to "value1", "header2" to "value2")
}

val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
testHeaderProvider
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }

mockEngine.requestHistory.first().headers["header1"] shouldBe "value1"
mockEngine.requestHistory.first().headers["header2"] shouldBe "value2"
}

@Test
fun `headers from HeaderProvider are ignored if timeout`() = doBlocking {
val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10))
val mockEngine = MockEngine {
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}

val testHeaderProvider =
object : HeaderProvider {
override val timeout: Duration
get() = 5.milliseconds

override suspend fun generateHeaders(): Map<String, String> {
delay(10.milliseconds)
return mapOf("header1" to "value1")
}
}

val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
testHeaderProvider
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }

mockEngine.requestHistory.first().headers.contains("header1") shouldBe false
}
}

@RunWith(Parameterized::class)
Expand All @@ -253,7 +327,14 @@ internal class ModelNamingTests(private val modelName: String, private val actua
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine, TEST_CLIENT_ID)
APIController(
"super_cool_test_key",
modelName,
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,14 @@ internal fun commonTest(
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val apiController =
APIController("super_cool_test_key", "gemini-pro", requestOptions, mockEngine, TEST_CLIENT_ID)
APIController(
"super_cool_test_key",
"gemini-pro",
requestOptions,
mockEngine,
TEST_CLIENT_ID,
null
)
CommonTestScope(channel, apiController).block()
}

Expand Down

0 comments on commit b65c155

Please sign in to comment.