Skip to content

Commit

Permalink
Merge dev branch (#57)
Browse files Browse the repository at this point in the history
Co-authored-by: Daymon <[email protected]>
  • Loading branch information
rlazo and daymxn authored Feb 13, 2024
1 parent dc21ffd commit 87fffa9
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 27 deletions.
1 change: 1 addition & 0 deletions .changes/behavior-attention-channel-bottle.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MINOR","changes":["Add RequestOptions; configuration points for backend implementation details such as api version and timeout."]}
1 change: 1 addition & 0 deletions .changes/condition-company-cloth-distribution.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Support a general model naming schema"]}
3 changes: 2 additions & 1 deletion generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ dependencies {

implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
implementation("androidx.core:core-ktx:1.12.0")
implementation("org.slf4j:slf4j-android:1.7.36")
implementation("org.slf4j:slf4j-nop:2.0.9")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3")
implementation("org.reactivestreams:reactive-streams:1.0.3")
Expand All @@ -94,6 +94,7 @@ dependencies {
testImplementation("junit:junit:4.13.2")
testImplementation("io.kotest:kotest-assertions-core:4.0.7")
testImplementation("io.kotest:kotest-assertions-jvm:4.0.7")
testImplementation("io.kotest:kotest-assertions-json:4.0.7")
testImplementation("io.ktor:ktor-client-mock:$ktorVersion")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerationConfig
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.SerializationException
Expand All @@ -45,13 +46,15 @@ import kotlinx.coroutines.flow.map
* @property generationConfig configuration parameters to use for content generation
* @property safetySettings the safety bounds to use alongside prompts during content
* generation
* @property requestOptions configuration options to utilize during backend communication
*/
class GenerativeModel
internal constructor(
val modelName: String,
val apiKey: String,
val generationConfig: GenerationConfig? = null,
val safetySettings: List<SafetySetting>? = null,
val requestOptions: RequestOptions = RequestOptions(),
private val controller: APIController
) {

Expand All @@ -61,7 +64,15 @@ internal constructor(
apiKey: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
) : this(modelName, apiKey, generationConfig, safetySettings, APIController(apiKey, modelName))
requestOptions: RequestOptions = RequestOptions(),
) : this(
modelName,
apiKey,
generationConfig,
safetySettings,
requestOptions,
APIController(apiKey, modelName, requestOptions.apiVersion, requestOptions.timeout)
)

/**
* Generates a response from the backend with the provided [Content]s.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlin.time.Duration
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.timeout
import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json

// TODO: Should these stay here or be moved elsewhere?
internal const val DOMAIN = "https://generativelanguage.googleapis.com/v1"
internal const val DOMAIN = "https://generativelanguage.googleapis.com"

internal val JSON = Json {
ignoreUnknownKeys = true
Expand All @@ -60,42 +61,46 @@ internal val JSON = Json {
* Exposed primarily for DI in tests.
* @property key The API key used for authentication.
* @property model The model to use for generation.
* @property apiVersion the endpoint version to communicate with.
* @property timeout the maximum amount of time for a request to take in the initial exchange.
*/
internal class APIController(
private val key: String,
model: String,
httpEngine: HttpClientEngine = OkHttp.create()
private val apiVersion: String,
private val timeout: Duration,
httpEngine: HttpClientEngine = OkHttp.create(),
) {
private val model = fullModelName(model)

private val client =
HttpClient(httpEngine) {
install(HttpTimeout) {
requestTimeoutMillis = HttpTimeout.INFINITE_TIMEOUT_MS
requestTimeoutMillis = timeout.inWholeMilliseconds
socketTimeoutMillis = 80_000
}
install(ContentNegotiation) { json(JSON) }
}

suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse {
return client
.post("$DOMAIN/$model:generateContent") { applyCommonConfiguration(request) }
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
client
.post("$DOMAIN/$apiVersion/$model:generateContent") { applyCommonConfiguration(request) }
.also { validateResponse(it) }
.body()
}

fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> {
return client.postStream("$DOMAIN/$model:streamGenerateContent?alt=sse") {
return client.postStream<GenerateContentResponse>(
"$DOMAIN/$apiVersion/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
}
}

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse {
return client
.post("$DOMAIN/$model:countTokens") { applyCommonConfiguration(request) }
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
client
.post("$DOMAIN/$apiVersion/$model:countTokens") { applyCommonConfiguration(request) }
.also { validateResponse(it) }
.body()
}

private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) {
when (request) {
Expand All @@ -113,8 +118,7 @@ internal class APIController(
*
* Models must be prepended with the `models/` prefix when communicating with the backend.
*/
private fun fullModelName(name: String): String =
name.takeIf { it.startsWith("models/") } ?: "models/$name"
private fun fullModelName(name: String): String =
  // Any name that already contains a "/" is used verbatim; bare names get the "models/" prefix.
  if ("/" in name) name else "models/$name"

/**
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.type

import com.google.ai.client.generativeai.GenerativeModel
import io.ktor.serialization.JsonConvertException
import kotlinx.coroutines.TimeoutCancellationException

/** Parent class for any errors that occur from [GenerativeModel]. */
sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = null) :
Expand All @@ -39,6 +40,8 @@ sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = nu
"Something went wrong while trying to deserialize a response from the server.",
cause
)
is TimeoutCancellationException ->
RequestTimeoutException("The request failed to complete in the allotted time.")
else -> UnknownException("Something unexpected happened.", cause)
}
}
Expand Down Expand Up @@ -84,6 +87,14 @@ class ResponseStoppedException(val response: GenerateContentResponse, cause: Thr
cause
)

/**
 * Signals that a request did not complete within the time allowed for it.
 *
 * Usually the result of a user-specified [timeout][RequestOptions.timeout].
 *
 * @param message description of the failure
 * @param cause the underlying error, if one is available
 */
class RequestTimeoutException(
  message: String,
  cause: Throwable? = null,
) : GoogleGenerativeAIException(message, cause)

/**
 * Fallback for failures that have no dedicated exception type.
 *
 * @param message description of the failure
 * @param cause the underlying error, if one is available
 */
class UnknownException(
  message: String,
  cause: Throwable? = null,
) : GoogleGenerativeAIException(message, cause)
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

import io.ktor.client.plugins.HttpTimeout
import kotlin.time.Duration
import kotlin.time.DurationUnit
import kotlin.time.toDuration

/**
 * Options that control how requests are issued to the backend.
 *
 * @property timeout upper bound on how long a request may take, measured from the first request
 * to the first response.
 * @property apiVersion the api endpoint to call; `"v1"` unless overridden.
 */
class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") {
  /**
   * Millisecond-based overload of the primary constructor.
   *
   * @param timeout the timeout in milliseconds; `null` is treated the same as the default,
   * [HttpTimeout.INFINITE_TIMEOUT_MS].
   * @param apiVersion the api endpoint to call.
   */
  @JvmOverloads
  constructor(
    timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
    apiVersion: String = "v1",
  ) : this(
    timeout = (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
    apiVersion = apiVersion,
  )
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,29 @@

package com.google.ai.client.generativeai

import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.RequestTimeoutException
import com.google.ai.client.generativeai.util.commonTest
import com.google.ai.client.generativeai.util.createGenerativeModel
import com.google.ai.client.generativeai.util.createResponses
import com.google.ai.client.generativeai.util.doBlocking
import com.google.ai.client.generativeai.util.prepareStreamingResponse
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.withTimeout
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

internal class GenerativeModelTests {
private val testTimeout = 5.seconds
Expand All @@ -45,4 +58,49 @@ internal class GenerativeModelTests {
}
}
}

@Test
fun `(generateContent) respects a custom timeout`() {
  val shortTimeout = RequestOptions(timeout = 2.seconds)
  commonTest(requestOptions = shortTimeout) {
    // The test body never writes a response, so the request is expected to run into the
    // 2 second request timeout; the surrounding withTimeout only bounds the test itself.
    shouldThrow<RequestTimeoutException> {
      withTimeout(testTimeout) { model.generateContent("d") }
    }
  }
}
}

/**
 * Parameterized check that the model name supplied by the caller ends up in the request URL in
 * its expected form.
 *
 * @property modelName the name handed to the model under test
 * @property actualName the text the request's encoded URL path must contain
 */
@RunWith(Parameterized::class)
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {

  @Test
  fun `request should include right model name`() = doBlocking {
    val responseChannel = ByteChannel(autoFlush = true)
    val engine = MockEngine {
      respond(
        responseChannel,
        HttpStatusCode.OK,
        headersOf(HttpHeaders.ContentType, "application/json")
      )
    }
    // Queue the canned streaming payload on the channel before any request is issued.
    for (chunk in prepareStreamingResponse(createResponses("Random"))) {
      responseChannel.writeFully(chunk)
    }
    val model = createGenerativeModel(modelName, "super_cool_test_key", RequestOptions(), engine)

    withTimeout(5.seconds) {
      model.generateContentStream().collect { response ->
        response.candidates.isEmpty() shouldBe false
        // Closing the channel after a response has been observed lets the stream finish.
        responseChannel.close()
      }
    }

    val firstRequest = engine.requestHistory.first()
    firstRequest.url.encodedPath shouldContain actualName
  }

  companion object {
    /** Pairs of (name given to the model, text expected in the request path). */
    @JvmStatic
    @Parameterized.Parameters
    fun data() =
      listOf(
        arrayOf("gemini-pro", "models/gemini-pro"),
        arrayOf("x/gemini-pro", "x/gemini-pro"),
        arrayOf("models/gemini-pro", "models/gemini-pro"),
        arrayOf("/modelname", "/modelname"),
        arrayOf("modifiedNaming/mymodel", "modifiedNaming/mymodel"),
      )
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.google.ai.client.generativeai.internal.api.shared.Content
import com.google.ai.client.generativeai.internal.api.shared.TextPart
import com.google.ai.client.generativeai.internal.util.SSE_SEPARATOR
import com.google.ai.client.generativeai.internal.util.send
import com.google.ai.client.generativeai.type.RequestOptions
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
Expand Down Expand Up @@ -93,19 +94,42 @@ internal typealias CommonTest = suspend CommonTestScope.() -> Unit
* ```
*
* @param status An optional [HttpStatusCode] to return as a response
* @param requestOptions Optional [RequestOptions] to utilize in the underlying controller
* @param block The test contents themselves, with the [CommonTestScope] implicitly provided
* @see CommonTestScope
*/
internal fun commonTest(status: HttpStatusCode = HttpStatusCode.OK, block: CommonTest) =
doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
}
val controller = APIController("super_cool_test_key", "gemini-pro", mockEngine)
val model = GenerativeModel("gemini-pro", "super_cool_test_key", controller = controller)
CommonTestScope(channel, model).block()
internal fun commonTest(
  status: HttpStatusCode = HttpStatusCode.OK,
  requestOptions: RequestOptions = RequestOptions(),
  block: CommonTest
) = doBlocking {
  // The same channel serves as the mock response body and is handed to the test scope.
  val responseChannel = ByteChannel(autoFlush = true)
  val engine = MockEngine {
    respond(
      responseChannel,
      status,
      headersOf(HttpHeaders.ContentType, "application/json")
    )
  }
  val scope =
    CommonTestScope(
      responseChannel,
      createGenerativeModel("gemini-pro", "super_cool_test_key", requestOptions, engine)
    )
  scope.block()
}

/**
 * Builds a [GenerativeModel] whose [APIController] sends its requests through [engine].
 *
 * The model and its controller are created from the same name, api key and request options.
 *
 * @param name the model name given to both the model and the controller
 * @param apikey the api key given to both the model and the controller
 * @param requestOptions the [RequestOptions] stored on the model; its api version and timeout are
 * also handed to the controller
 * @param engine the [MockEngine] used as the controller's http engine
 */
internal fun createGenerativeModel(
  name: String,
  apikey: String,
  requestOptions: RequestOptions = RequestOptions(),
  engine: MockEngine
) =
  GenerativeModel(
    name,
    apikey,
    // Forwarded so the model's requestOptions property matches what the controller was built with.
    requestOptions = requestOptions,
    controller =
      // Uses the caller's key rather than a fixed literal, keeping model and controller in sync.
      APIController(apikey, name, requestOptions.apiVersion, requestOptions.timeout, engine)
  )

/**
* A variant of [commonTest] for performing *streaming-based* snapshot tests.
Expand Down

0 comments on commit 87fffa9

Please sign in to comment.