Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RequestOptions (#52) #55

Merged
merged 1 commit into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .changes/behavior-attention-channel-bottle.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MINOR","changes":["Add RequestOptions; configuration points for backend implementation details such as api version and timeout."]}
2 changes: 1 addition & 1 deletion generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ dependencies {

implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
implementation("androidx.core:core-ktx:1.12.0")
implementation("org.slf4j:slf4j-android:1.7.36")
implementation("org.slf4j:slf4j-nop:2.0.9")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3")
implementation("org.reactivestreams:reactive-streams:1.0.3")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerationConfig
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.SerializationException
Expand All @@ -45,13 +46,15 @@ import kotlinx.coroutines.flow.map
* @property generationConfig configuration parameters to use for content generation
* @property safetySettings the safety bounds to use alongside prompts during content
* generation
* @property requestOptions configuration options to utilize during backend communication
*/
class GenerativeModel
internal constructor(
val modelName: String,
val apiKey: String,
val generationConfig: GenerationConfig? = null,
val safetySettings: List<SafetySetting>? = null,
val requestOptions: RequestOptions = RequestOptions(),
private val controller: APIController
) {

Expand All @@ -61,7 +64,15 @@ internal constructor(
apiKey: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
) : this(modelName, apiKey, generationConfig, safetySettings, APIController(apiKey, modelName))
requestOptions: RequestOptions = RequestOptions(),
) : this(
modelName,
apiKey,
generationConfig,
safetySettings,
requestOptions,
APIController(apiKey, modelName, requestOptions.apiVersion, requestOptions.timeout)
)

/**
* Generates a response from the backend with the provided [Content]s.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlin.time.Duration
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.timeout
import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json

// TODO: Should these stay here or be moved elsewhere?
internal const val DOMAIN = "https://generativelanguage.googleapis.com/v1"
internal const val DOMAIN = "https://generativelanguage.googleapis.com"

internal val JSON = Json {
ignoreUnknownKeys = true
Expand All @@ -60,42 +61,46 @@ internal val JSON = Json {
* Exposed primarily for DI in tests.
* @property key The API key used for authentication.
* @property model The model to use for generation.
* @property apiVersion the endpoint version to communicate with.
* @property timeout the maximum amount of time for a request to take in the initial exchange.
*/
internal class APIController(
private val key: String,
model: String,
httpEngine: HttpClientEngine = OkHttp.create()
private val apiVersion: String,
private val timeout: Duration,
httpEngine: HttpClientEngine = OkHttp.create(),
) {
private val model = fullModelName(model)

private val client =
HttpClient(httpEngine) {
install(HttpTimeout) {
requestTimeoutMillis = HttpTimeout.INFINITE_TIMEOUT_MS
requestTimeoutMillis = timeout.inWholeMilliseconds
socketTimeoutMillis = 80_000
}
install(ContentNegotiation) { json(JSON) }
}

suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse {
return client
.post("$DOMAIN/$model:generateContent") { applyCommonConfiguration(request) }
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
client
.post("$DOMAIN/$apiVersion/$model:generateContent") { applyCommonConfiguration(request) }
.also { validateResponse(it) }
.body()
}

fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> {
return client.postStream("$DOMAIN/$model:streamGenerateContent?alt=sse") {
return client.postStream<GenerateContentResponse>(
"$DOMAIN/$apiVersion/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
}
}

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse {
return client
.post("$DOMAIN/$model:countTokens") { applyCommonConfiguration(request) }
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
client
.post("$DOMAIN/$apiVersion/$model:countTokens") { applyCommonConfiguration(request) }
.also { validateResponse(it) }
.body()
}

private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) {
when (request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.type

import com.google.ai.client.generativeai.GenerativeModel
import io.ktor.serialization.JsonConvertException
import kotlinx.coroutines.TimeoutCancellationException

/** Parent class for any errors that occur from [GenerativeModel]. */
sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = null) :
Expand All @@ -39,6 +40,8 @@ sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = nu
"Something went wrong while trying to deserialize a response from the server.",
cause
)
is TimeoutCancellationException ->
RequestTimeoutException("The request failed to complete in the allotted time.")
else -> UnknownException("Something unexpected happened.", cause)
}
}
Expand Down Expand Up @@ -84,6 +87,14 @@ class ResponseStoppedException(val response: GenerateContentResponse, cause: Thr
cause
)

/**
 * Signals that a request did not complete within the time it was allowed.
 *
 * The limit is usually a caller-provided [timeout][RequestOptions.timeout].
 */
class RequestTimeoutException(
  message: String,
  cause: Throwable? = null,
) : GoogleGenerativeAIException(message, cause)

/** Fallback exception for failures that are not covered by a more specific type. */
class UnknownException(
  message: String,
  cause: Throwable? = null,
) : GoogleGenerativeAIException(message, cause)
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

import io.ktor.client.plugins.HttpTimeout
import kotlin.time.Duration
import kotlin.time.DurationUnit
import kotlin.time.toDuration

/**
 * Options that control how requests are performed against the backend.
 *
 * @property timeout upper bound on how long a request may take, measured from the first request
 *   to the first response.
 * @property apiVersion the api endpoint version to call; defaults to `"v1"`.
 */
class RequestOptions(
  val timeout: Duration,
  val apiVersion: String = "v1",
) {
  /**
   * Convenience constructor that accepts the timeout as a raw number of milliseconds.
   *
   * @param timeout the timeout in milliseconds; when `null` (or omitted),
   *   [HttpTimeout.INFINITE_TIMEOUT_MS] is used instead.
   * @param apiVersion the api endpoint version to call; defaults to `"v1"`.
   */
  @JvmOverloads
  constructor(
    timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
    apiVersion: String = "v1",
  ) : this(
    timeout = (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
    apiVersion = apiVersion,
  )
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@

package com.google.ai.client.generativeai

import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.RequestTimeoutException
import com.google.ai.client.generativeai.util.commonTest
import com.google.ai.client.generativeai.util.createResponses
import com.google.ai.client.generativeai.util.prepareStreamingResponse
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
Expand All @@ -45,4 +48,12 @@ internal class GenerativeModelTests {
}
}
}

@Test
fun `(generateContent) respects a custom timeout`() =
  commonTest(requestOptions = RequestOptions(timeout = 2.seconds)) {
    // The mocked backend never answers here, so the call is expected to surface as the
    // library's own timeout error. withTimeout(testTimeout) bounds the call as a whole
    // (testTimeout is declared outside this view).
    shouldThrow<RequestTimeoutException> {
      withTimeout(testTimeout) {
        model.generateContent("d")
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.google.ai.client.generativeai.internal.api.shared.Content
import com.google.ai.client.generativeai.internal.api.shared.TextPart
import com.google.ai.client.generativeai.internal.util.SSE_SEPARATOR
import com.google.ai.client.generativeai.internal.util.send
import com.google.ai.client.generativeai.type.RequestOptions
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
Expand Down Expand Up @@ -93,19 +94,30 @@ internal typealias CommonTest = suspend CommonTestScope.() -> Unit
* ```
*
* @param status An optional [HttpStatusCode] to return as a response
* @param requestOptions Optional [RequestOptions] to utilize in the underlying controller
* @param block The test contents themselves, with the [CommonTestScope] implicitly provided
* @see CommonTestScope
*/
internal fun commonTest(status: HttpStatusCode = HttpStatusCode.OK, block: CommonTest) =
doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val controller = APIController("super_cool_test_key", "gemini-pro", mockEngine)
val model = GenerativeModel("gemini-pro", "super_cool_test_key", controller = controller)
CommonTestScope(channel, model).block()
internal fun commonTest(
  status: HttpStatusCode = HttpStatusCode.OK,
  requestOptions: RequestOptions = RequestOptions(),
  block: CommonTest
) = doBlocking {
  // Every mocked request is answered from this channel; the same channel is handed to the
  // test scope below so the test body controls what the "backend" sends.
  val channel = ByteChannel(autoFlush = true)
  val mockEngine = MockEngine {
    respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
  }
  // The controller is what actually applies the api version and timeout to outgoing requests.
  val controller =
    APIController(
      "super_cool_test_key",
      "gemini-pro",
      requestOptions.apiVersion,
      requestOptions.timeout,
      mockEngine
    )
  // Pass the same options to the model so that model.requestOptions agrees with the
  // configuration of the controller it wraps, instead of silently reporting the defaults.
  val model =
    GenerativeModel(
      "gemini-pro",
      "super_cool_test_key",
      requestOptions = requestOptions,
      controller = controller
    )
  CommonTestScope(channel, model).block()
}

/**
* A variant of [commonTest] for performing *streaming-based* snapshot tests.
Expand Down
Loading