Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev-0.1.3' into modelname.feature
Browse files Browse the repository at this point in the history
  • Loading branch information
rlazo committed Feb 12, 2024
2 parents 7ab7851 + fff522d commit 073a575
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 28 deletions.
1 change: 1 addition & 0 deletions .changes/behavior-attention-channel-bottle.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MINOR","changes":["Add RequestOptions; configuration points for backend implementation details such as api version and timeout."]}
1 change: 1 addition & 0 deletions .changes/crowd-comfort-color-burst.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"PATCH","changes":["GenerativeModelFutures now correctly passes the prompt through"]}
2 changes: 1 addition & 1 deletion generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ dependencies {

implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
implementation("androidx.core:core-ktx:1.12.0")
implementation("org.slf4j:slf4j-android:1.7.36")
implementation("org.slf4j:slf4j-nop:2.0.9")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3")
implementation("org.reactivestreams:reactive-streams:1.0.3")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerationConfig
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.SerializationException
Expand All @@ -45,13 +46,15 @@ import kotlinx.coroutines.flow.map
* @property generationConfig configuration parameters to use for content generation
* @property safetySettings the safety bounds to use during alongside prompts during content
* generation
* @property requestOptions configuration options to utilize during backend communication
*/
class GenerativeModel
internal constructor(
val modelName: String,
val apiKey: String,
val generationConfig: GenerationConfig? = null,
val safetySettings: List<SafetySetting>? = null,
val requestOptions: RequestOptions = RequestOptions(),
private val controller: APIController
) {

Expand All @@ -61,7 +64,15 @@ internal constructor(
apiKey: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
) : this(modelName, apiKey, generationConfig, safetySettings, APIController(apiKey, modelName))
requestOptions: RequestOptions = RequestOptions(),
) : this(
modelName,
apiKey,
generationConfig,
safetySettings,
requestOptions,
APIController(apiKey, modelName, requestOptions.apiVersion, requestOptions.timeout)
)

/**
* Generates a response from the backend with the provided [Content]s.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlin.time.Duration
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.timeout
import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json

// TODO: Should these stay here or be moved elsewhere?
internal const val DOMAIN = "https://generativelanguage.googleapis.com/v1"
internal const val DOMAIN = "https://generativelanguage.googleapis.com"

internal val JSON = Json {
ignoreUnknownKeys = true
Expand All @@ -60,42 +61,46 @@ internal val JSON = Json {
* Exposed primarily for DI in tests.
* @property key The API key used for authentication.
* @property model The model to use for generation.
* @property apiVersion the endpoint version to communicate with.
* @property timeout the maximum amount of time for a request to take in the initial exchange.
*/
internal class APIController(
private val key: String,
model: String,
httpEngine: HttpClientEngine = OkHttp.create()
private val apiVersion: String,
private val timeout: Duration,
httpEngine: HttpClientEngine = OkHttp.create(),
) {
private val model = fullModelName(model)

private val client =
HttpClient(httpEngine) {
install(HttpTimeout) {
requestTimeoutMillis = HttpTimeout.INFINITE_TIMEOUT_MS
requestTimeoutMillis = timeout.inWholeMilliseconds
socketTimeoutMillis = 80_000
}
install(ContentNegotiation) { json(JSON) }
}

suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse {
return client
.post("$DOMAIN/$model:generateContent") { applyCommonConfiguration(request) }
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
client
.post("$DOMAIN/$apiVersion/$model:generateContent") { applyCommonConfiguration(request) }
.also { validateResponse(it) }
.body()
}

fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> {
return client.postStream("$DOMAIN/$model:streamGenerateContent?alt=sse") {
return client.postStream<GenerateContentResponse>(
"$DOMAIN/$apiVersion/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
}
}

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse {
return client
.post("$DOMAIN/$model:countTokens") { applyCommonConfiguration(request) }
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
client
.post("$DOMAIN/$apiVersion/$model:countTokens") { applyCommonConfiguration(request) }
.also { validateResponse(it) }
.body()
}

private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) {
when (request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ abstract class GenerativeModelFutures internal constructor() {
SuspendToFutureAdapter.launchFuture { model.generateContent(*prompt) }

override fun generateContentStream(vararg prompt: Content): Publisher<GenerateContentResponse> =
model.generateContentStream().asPublisher()
model.generateContentStream(*prompt).asPublisher()

override fun countTokens(vararg prompt: Content): ListenableFuture<CountTokensResponse> =
SuspendToFutureAdapter.launchFuture { model.countTokens(*prompt) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.type

import com.google.ai.client.generativeai.GenerativeModel
import io.ktor.serialization.JsonConvertException
import kotlinx.coroutines.TimeoutCancellationException

/** Parent class for any errors that occur from [GenerativeModel]. */
sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = null) :
Expand All @@ -39,6 +40,8 @@ sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = nu
"Something went wrong while trying to deserialize a response from the server.",
cause
)
is TimeoutCancellationException ->
RequestTimeoutException("The request failed to complete in the allotted time.")
else -> UnknownException("Something unexpected happened.", cause)
}
}
Expand Down Expand Up @@ -84,6 +87,14 @@ class ResponseStoppedException(val response: GenerateContentResponse, cause: Thr
cause
)

/**
* A request took too long to complete.
*
* Usually occurs due to a user specified [timeout][RequestOptions.timeout].
*/
class RequestTimeoutException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)

/** Catch all case for exceptions not explicitly expected. */
class UnknownException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

import io.ktor.client.plugins.HttpTimeout
import kotlin.time.Duration
import kotlin.time.DurationUnit
import kotlin.time.toDuration

/**
* Configurable options unique to how requests to the backend are performed.
*
* @property timeout the maximum amount of time for a request to take, from the first request to
* first response.
* @property apiVersion the api endpoint to call.
*/
class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") {
@JvmOverloads
constructor(
timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
apiVersion: String = "v1"
) : this(
(timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
apiVersion
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

package com.google.ai.client.generativeai

import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.RequestTimeoutException
import com.google.ai.client.generativeai.util.commonTest
import com.google.ai.client.generativeai.util.createGenerativeModel
import com.google.ai.client.generativeai.util.createResponses
import com.google.ai.client.generativeai.util.doBlocking
import com.google.ai.client.generativeai.util.prepareStreamingResponse
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.client.engine.mock.MockEngine
Expand Down Expand Up @@ -55,6 +58,14 @@ internal class GenerativeModelTests {
}
}
}

@Test
fun `(generateContent) respects a custom timeout`() =
commonTest(requestOptions = RequestOptions(2.seconds)) {
shouldThrow<RequestTimeoutException> {
withTimeout(testTimeout) { model.generateContent("d") }
}
}
}

@RunWith(Parameterized::class)
Expand All @@ -67,7 +78,8 @@ internal class ModelNamingTests(private val modelName: String, private val actua
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
val model = createGenerativeModel(modelName, "super_cool_test_key", mockEngine)
val model =
createGenerativeModel(modelName, "super_cool_test_key", RequestOptions(), mockEngine)

withTimeout(5.seconds) {
model.generateContentStream().collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.google.ai.client.generativeai.internal.api.shared.Content
import com.google.ai.client.generativeai.internal.api.shared.TextPart
import com.google.ai.client.generativeai.internal.util.SSE_SEPARATOR
import com.google.ai.client.generativeai.internal.util.send
import com.google.ai.client.generativeai.type.RequestOptions
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
Expand Down Expand Up @@ -93,22 +94,42 @@ internal typealias CommonTest = suspend CommonTestScope.() -> Unit
* ```
*
* @param status An optional [HttpStatusCode] to return as a response
* @param requestOptions Optional [RequestOptions] to utilize in the underlying controller
* @param block The test contents themselves, with the [CommonTestScope] implicitly provided
* @see CommonTestScope
*/
internal fun commonTest(status: HttpStatusCode = HttpStatusCode.OK, block: CommonTest) =
doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val model = createGenerativeModel("gemini-pro", "super_cool_test_key", mockEngine)
CommonTestScope(channel, model).block()
internal fun commonTest(
status: HttpStatusCode = HttpStatusCode.OK,
requestOptions: RequestOptions = RequestOptions(),
block: CommonTest
) = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val model = createGenerativeModel("gemini-pro", "super_cool_test_key", requestOptions, mockEngine)
CommonTestScope(channel, model).block()
}

/** Simple wrapper that guarantees the model and APIController are created using the same data */
internal fun createGenerativeModel(name: String, apikey: String, engine: MockEngine) =
GenerativeModel(name, apikey, controller = APIController(apikey, name, engine))
internal fun createGenerativeModel(
name: String,
apikey: String,
requestOptions: RequestOptions = RequestOptions(),
engine: MockEngine
) =
GenerativeModel(
name,
apikey,
controller =
APIController(
"super_cool_test_key",
name,
requestOptions.apiVersion,
requestOptions.timeout,
engine
)
)

/**
* A variant of [commonTest] for performing *streaming-based* snapshot tests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ abstract class ApiPlugin : Plugin<Project> {

context(Project)
private fun ApiPluginExtension.commonConfiguration() {
val latestApiFile = project.file("api/${project.version}.api")
val latestApiFile = rootProject.file("api/${project.version}.api")

apiFile.convention(latestApiFile)
}
Expand Down

0 comments on commit 073a575

Please sign in to comment.