Skip to content

Commit

Permalink
Ensure header is provided from coroutine
Browse files Browse the repository at this point in the history
  • Loading branch information
daymxn committed Apr 4, 2024
1 parent 0b2165f commit 1abad6a
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 80 deletions.
1 change: 1 addition & 0 deletions common/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ android {
}
kotlinOptions {
jvmTarget = "17"
freeCompilerArgs += "-Xcontext-receivers"
}

testOptions {
Expand Down
2 changes: 1 addition & 1 deletion common/gradle.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.1.0
version=0.1.1
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,16 @@ import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlin.time.Duration
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.json.Json
import kotlin.time.Duration

val JSON = Json {
ignoreUnknownKeys = true
Expand Down Expand Up @@ -100,6 +98,7 @@ internal constructor(
client
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
applyHeaderProvider()
}
.also { validateResponse(it) }
.body<GenerateContentResponse>()
Expand All @@ -123,6 +122,7 @@ internal constructor(
client
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") {
applyCommonConfiguration(request)
applyHeaderProvider()
}
.also { validateResponse(it) }
.body()
Expand All @@ -138,30 +138,77 @@ internal constructor(
contentType(ContentType.Application.Json)
header("x-goog-api-key", key)
header("x-goog-api-client", apiClient)
}

// Obtain additional headers from provider
private suspend fun HttpRequestBuilder.applyHeaderProvider() {
if (headerProvider != null) {
runBlocking(Dispatchers.IO) {
try {
withTimeout(headerProvider.timeout) {
for ((tag, value) in headerProvider.generateHeaders()) {
header(tag, value)
}
try {
withTimeout(headerProvider.timeout) {
for ((tag, value) in headerProvider.generateHeaders()) {
header(tag, value)
}
} catch (e: TimeoutCancellationException) {
Log.w(TAG,"HeaderProvided timed out without generating headers, ignoring")
}
} catch (e: TimeoutCancellationException) {
Log.w(TAG, "HeaderProvided timed out without generating headers, ignoring")
}
}
}

/**
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response
* objects of type [R]. The response is expected to be a stream of JSON objects that are parsed in
* real-time as they are received from the server.
*
* This function is intended for internal use within the client that handles streaming responses.
*
* Example usage:
* ```
* val client: HttpClient = HttpClient(CIO)
* val request: Request = GenerateContentRequest(...)
* val url: String = "http://example.com/stream"
*
* val responses: GenerateContentResponse = client.postStream(url) {
* setBody(request)
* contentType(ContentType.Application.Json)
* }
* responses.collect {
* println("Got a response: $it")
* }
* ```
*
* @param R The type of the response object.
* @param url The URL to which the POST request will be made.
* @param config An optional [HttpRequestBuilder] callback for request configuration.
* @return A [Flow] of response objects of type [R].
*/
private inline fun <reified R : Response> HttpClient.postStream(
url: String,
crossinline config: HttpRequestBuilder.() -> Unit = {},
): Flow<R> = channelFlow {
launch(CoroutineName("postStream")) {
preparePost(url) {
applyHeaderProvider()
config()
}
.execute {
validateResponse(it)

val channel = it.bodyAsChannel()
val flow = JSON.decodeToFlow<R>(channel)

flow.collect { send(it) }
}
}
}

companion object {
private val TAG = APIController::class.java.simpleName
}
}

interface HeaderProvider {
val timeout: Duration

suspend fun generateHeaders(): Map<String, String>
}

Expand All @@ -172,50 +219,6 @@ interface HeaderProvider {
*/
private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"

/**
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects
* of type [R]. The response is expected to be a stream of JSON objects that are parsed in real-time
* as they are received from the server.
*
* This function is intended for internal use within the client that handles streaming responses.
*
* Example usage:
* ```
* val client: HttpClient = HttpClient(CIO)
* val request: Request = GenerateContentRequest(...)
* val url: String = "http://example.com/stream"
*
* val responses: GenerateContentResponse = client.postStream(url) {
* setBody(request)
* contentType(ContentType.Application.Json)
* }
* responses.collect {
* println("Got a response: $it")
* }
* ```
*
* @param R The type of the response object.
* @param url The URL to which the POST request will be made.
* @param config An optional [HttpRequestBuilder] callback for request configuration.
* @return A [Flow] of response objects of type [R].
*/
private inline fun <reified R : Response> HttpClient.postStream(
url: String,
crossinline config: HttpRequestBuilder.() -> Unit = {}
): Flow<R> = channelFlow {
launch(CoroutineName("postStream")) {
preparePost(url) { config() }
.execute {
validateResponse(it)

val channel = it.bodyAsChannel()
val flow = JSON.decodeToFlow<R>(channel)

flow.collect { send(it) }
}
}
}

private suspend fun validateResponse(response: HttpResponse) {
if (response.status == HttpStatusCode.OK) return
val text = response.bodyAsText()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
import kotlinx.coroutines.delay
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.delay
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.encodeToString
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds

private val TEST_CLIENT_ID = "genai-android/test"

Expand Down Expand Up @@ -78,7 +78,6 @@ internal class APIControllerTests {
}
}
}

}

internal class RequestFormatTests {
Expand Down Expand Up @@ -258,15 +257,14 @@ internal class RequestFormatTests {
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}

val testHeaderProvider = object : HeaderProvider {
override val timeout: Duration
get() = 5.seconds
val testHeaderProvider =
object : HeaderProvider {
override val timeout: Duration
get() = 5.seconds

override suspend fun generateHeaders(): Map<String, String> = mapOf(
"header1" to "value1",
"header2" to "value2"
)
}
override suspend fun generateHeaders(): Map<String, String> =
mapOf("header1" to "value1", "header2" to "value2")
}

val controller =
APIController(
Expand All @@ -291,17 +289,16 @@ internal class RequestFormatTests {
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}

val testHeaderProvider = object : HeaderProvider {
override val timeout: Duration
get() = 5.milliseconds
val testHeaderProvider =
object : HeaderProvider {
override val timeout: Duration
get() = 5.milliseconds

override suspend fun generateHeaders(): Map<String, String> {
delay(10.milliseconds)
return mapOf(
"header1" to "value1"
)
override suspend fun generateHeaders(): Map<String, String> {
delay(10.milliseconds)
return mapOf("header1" to "value1")
}
}
}

val controller =
APIController(
Expand Down Expand Up @@ -330,7 +327,14 @@ internal class ModelNamingTests(private val modelName: String, private val actua
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val controller =
APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine, TEST_CLIENT_ID, null)
APIController(
"super_cool_test_key",
modelName,
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null
)

withTimeout(5.seconds) {
controller.generateContentStream(textGenerateContentRequest("cats")).collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,14 @@ internal fun commonTest(
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val apiController =
APIController("super_cool_test_key", "gemini-pro", requestOptions, mockEngine, TEST_CLIENT_ID, null)
APIController(
"super_cool_test_key",
"gemini-pro",
requestOptions,
mockEngine,
TEST_CLIENT_ID,
null
)
CommonTestScope(channel, apiController).block()
}

Expand Down

0 comments on commit 1abad6a

Please sign in to comment.