From b8857dd5f3f899e570ac719716ad49bbb500150e Mon Sep 17 00:00:00 2001 From: Rodrigo Lazo Date: Fri, 15 Mar 2024 13:02:02 -0400 Subject: [PATCH] Initial commit of common (#84) Changes are the minimum necessary to extract the common functionality into the new sdk, mostly renaming packages, removing internal declarations, and that's it. Follow up CLs will add tests, remove data classes, and improve exception reporting. --- common/.gitignore | 1 + common/build.gradle.kts | 126 ++++++++++++ common/consumer-rules.pro | 21 ++ common/gradle.properties | 1 + common/proguard-rules.pro | 21 ++ common/src/main/AndroidManifest.xml | 18 ++ .../generativeai/common/APIController.kt | 185 ++++++++++++++++++ .../client/generativeai/common/Exceptions.kt | 113 +++++++++++ .../ai/client/generativeai/common/Request.kt | 36 ++++ .../generativeai/common/RequestOptions.kt | 40 ++++ .../ai/client/generativeai/common/Response.kt | 34 ++++ .../generativeai/common/client/Types.kt | 30 +++ .../generativeai/common/server/Types.kt | 104 ++++++++++ .../generativeai/common/shared/Types.kt | 82 ++++++++ .../client/generativeai/common/util/kotlin.kt | 41 ++++ .../client/generativeai/common/util/ktor.kt | 101 ++++++++++ .../generativeai/common/util/serialization.kt | 83 ++++++++ settings.gradle.kts | 1 + 18 files changed, 1038 insertions(+) create mode 100644 common/.gitignore create mode 100644 common/build.gradle.kts create mode 100644 common/consumer-rules.pro create mode 100644 common/gradle.properties create mode 100644 common/proguard-rules.pro create mode 100644 common/src/main/AndroidManifest.xml create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/util/kotlin.kt create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/util/ktor.kt create mode 100644 common/src/main/kotlin/com/google/ai/client/generativeai/common/util/serialization.kt diff --git a/common/.gitignore b/common/.gitignore new file mode 100644 index 00000000..796b96d1 --- /dev/null +++ b/common/.gitignore @@ -0,0 +1 @@ +/build diff --git a/common/build.gradle.kts b/common/build.gradle.kts new file mode 100644 index 00000000..bd02f5fa --- /dev/null +++ b/common/build.gradle.kts @@ -0,0 +1,126 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id("com.android.library") + id("maven-publish") + id("com.ncorti.ktfmt.gradle") + id("changelog-plugin") + id("release-plugin") + kotlin("android") + kotlin("plugin.serialization") +} + +ktfmt { + googleStyle() +} + +android { + namespace = "com.google.ai.client.generativeai.common" + compileSdk = 34 + + buildFeatures.buildConfig = true + + defaultConfig { + minSdk = 21 + + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + consumerProguardFiles("consumer-rules.pro") + + buildConfigField("String", "VERSION_NAME", "\"${project.version.toString()}\"") + } + + publishing { + singleVariant("release") { + withSourcesJar() + } + } + + buildTypes { + release { + isMinifyEnabled = false + proguardFiles( + getDefaultProguardFile("proguard-android-optimize.txt"), + "proguard-rules.pro" + ) + } + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + kotlinOptions { + jvmTarget = "17" + } + + testOptions { + unitTests.isReturnDefaultValues = true + } +} + +dependencies { + val ktorVersion = "2.3.2" + + implementation("io.ktor:ktor-client-okhttp:$ktorVersion") + implementation("io.ktor:ktor-client-core:$ktorVersion") + implementation("io.ktor:ktor-client-content-negotiation:$ktorVersion") + implementation("io.ktor:ktor-serialization-kotlinx-json:$ktorVersion") + implementation("io.ktor:ktor-client-logging:$ktorVersion") + + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1") + implementation("androidx.core:core-ktx:1.12.0") + implementation("org.slf4j:slf4j-nop:2.0.9") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3") + implementation("org.reactivestreams:reactive-streams:1.0.3") + + implementation("com.google.guava:listenablefuture:1.0") + implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha02") + implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha02") + testImplementation("junit:junit:4.13.2") + testImplementation("io.kotest:kotest-assertions-core:4.0.7") + testImplementation("io.kotest:kotest-assertions-jvm:4.0.7") + testImplementation("io.kotest:kotest-assertions-json:4.0.7") + testImplementation("io.ktor:ktor-client-mock:$ktorVersion") + androidTestImplementation("androidx.test.ext:junit:1.1.5") + androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") +} + +publishing { + publications { + register("release") { + groupId = "com.google.ai.client.generativeai" + artifactId = "common" + version = project.version.toString() + pom { + licenses { + license { + name = "The Apache License, Version 2.0" + url = "http://www.apache.org/licenses/LICENSE-2.0.txt" + } + } + } + afterEvaluate { + from(components["release"]) + } + } + } + repositories { + maven { + url = uri("${projectDir}/m2") + } + } +} diff --git a/common/consumer-rules.pro b/common/consumer-rules.pro new file mode 100644 index 00000000..f1b42451 --- /dev/null +++ b/common/consumer-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/common/gradle.properties b/common/gradle.properties new file mode 100644 index 00000000..de55ab65 --- /dev/null +++ b/common/gradle.properties @@ -0,0 +1 @@ +version=0.1.0 diff --git a/common/proguard-rules.pro b/common/proguard-rules.pro new file mode 100644 index 00000000..f1b42451 --- /dev/null +++ b/common/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/common/src/main/AndroidManifest.xml b/common/src/main/AndroidManifest.xml new file mode 100644 index 00000000..7ebb9c28 --- /dev/null +++ b/common/src/main/AndroidManifest.xml @@ -0,0 +1,18 @@ + + + + + diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt new file mode 100644 index 00000000..57e5413a --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt @@ -0,0 +1,185 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import com.google.ai.client.generativeai.common.util.decodeToFlow +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.engine.HttpClientEngine +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.preparePost +import io.ktor.client.request.setBody +import io.ktor.client.statement.HttpResponse +import io.ktor.client.statement.bodyAsChannel +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.serialization.kotlinx.json.json +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.timeout +import kotlinx.coroutines.launch +import kotlinx.serialization.json.Json + +const val DOMAIN = "https://generativelanguage.googleapis.com" + +val JSON = Json { + ignoreUnknownKeys = true + prettyPrint = false +} + +/** + * Backend class for interfacing with the Gemini API. + * + * This class handles making HTTP requests to the API and streaming the responses back. + * + * @param httpEngine The HTTP client engine to be used for making requests. Defaults to CIO engine. + * Exposed primarily for DI in tests. + * @property key The API key used for authentication. + * @property model The model to use for generation. + * @property apiVersion the endpoint version to communicate with. + * @property timeout the maximum amount of time for a request to take in the initial exchange. + */ +class APIController( + private val key: String, + model: String, + private val requestOptions: RequestOptions, + httpEngine: HttpClientEngine = OkHttp.create(), +) { + private val model = fullModelName(model) + + private val client = + HttpClient(httpEngine) { + install(HttpTimeout) { + requestTimeoutMillis = requestOptions.timeout.inWholeMilliseconds + socketTimeoutMillis = 80_000 + } + install(ContentNegotiation) { json(JSON) } + } + + suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse = + client + .post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") { + applyCommonConfiguration(request) + } + .also { validateResponse(it) } + .body() + + fun generateContentStream(request: GenerateContentRequest): Flow { + return client.postStream( + "$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" + ) { + applyCommonConfiguration(request) + } + } + + suspend fun countTokens(request: CountTokensRequest): CountTokensResponse = + client + .post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") { + applyCommonConfiguration(request) + } + .also { validateResponse(it) } + .body() + + private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) { + when (request) { + is GenerateContentRequest -> setBody(request) + is CountTokensRequest -> setBody(request) + } + contentType(ContentType.Application.Json) + header("x-goog-api-key", key) + header("x-goog-api-client", "genai-android/${BuildConfig.VERSION_NAME}") + } +} + +/** + * Ensures the model name provided has a `models/` prefix + * + * Models must be prepended with the `models/` prefix when communicating with the backend. + */ +private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name" + +/** + * Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects + * of type [R]. The response is expected to be a stream of JSON objects that are parsed in real-time + * as they are received from the server. + * + * This function is intended for internal use within the client that handles streaming responses. + * + * Example usage: + * ``` + * val client: HttpClient = HttpClient(CIO) + * val request: Request = GenerateContentRequest(...) + * val url: String = "http://example.com/stream" + * + * val responses: GenerateContentResponse = client.postStream(url) { + * setBody(request) + * contentType(ContentType.Application.Json) + * } + * responses.collect { + * println("Got a response: $it") + * } + * ``` + * + * @param R The type of the response object. + * @param url The URL to which the POST request will be made. + * @param config An optional [HttpRequestBuilder] callback for request configuration. + * @return A [Flow] of response objects of type [R]. + */ +private inline fun HttpClient.postStream( + url: String, + crossinline config: HttpRequestBuilder.() -> Unit = {} +): Flow = channelFlow { + launch(CoroutineName("postStream")) { + preparePost(url) { config() } + .execute { + validateResponse(it) + + val channel = it.bodyAsChannel() + val flow = JSON.decodeToFlow(channel) + + flow.collect { send(it) } + } + } +} + +private suspend fun validateResponse(response: HttpResponse) { + if (response.status != HttpStatusCode.OK) { + val text = response.bodyAsText() + val message = + try { + JSON.decodeFromString(text).error.message + } catch (e: Throwable) { + "Unexpected Response:\n$text" + } + if (message.contains("API key not valid")) { + throw InvalidAPIKeyException(message) + } + // TODO (b/325117891): Use a better method than string matching. + if (message == "User location is not supported for the API use.") { + throw UnsupportedUserLocationException() + } + throw ServerException(message) + } +} diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt new file mode 100644 index 00000000..d28ea2ff --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt @@ -0,0 +1,113 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import io.ktor.serialization.JsonConvertException +import kotlinx.coroutines.TimeoutCancellationException + +/** Parent class for any errors that occur. */ +sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = null) : + RuntimeException(message, cause) { + companion object { + + /** + * Converts a [Throwable] to a [GoogleGenerativeAIException]. + * + * Will populate default messages as expected, and propagate the provided [cause] through the + * resulting exception. + */ + fun from(cause: Throwable): GoogleGenerativeAIException = + when (cause) { + is GoogleGenerativeAIException -> cause + is JsonConvertException, + is kotlinx.serialization.SerializationException -> + SerializationException( + "Something went wrong while trying to deserialize a response from the server.", + cause + ) + is TimeoutCancellationException -> + RequestTimeoutException("The request failed to complete in the allotted time.") + else -> UnknownException("Something unexpected happened.", cause) + } + } +} + +/** Something went wrong while trying to deserialize a response from the server. */ +class SerializationException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** The server responded with a non 200 response code. */ +class ServerException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** The server responded that the API Key is no valid. */ +class InvalidAPIKeyException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** + * A request was blocked for some reason. + * + * See the [response's][response] `promptFeedback.blockReason` for more information. + * + * @property response the full server response for the request. + */ +class PromptBlockedException(val response: GenerateContentResponse, cause: Throwable? = null) : + GoogleGenerativeAIException( + "Prompt was blocked: ${response.promptFeedback?.blockReason?.name}", + cause + ) + +/** + * The user's location (region) is not supported by the API. + * + * See the Google documentation for a + * [list of regions](https://ai.google.dev/available_regions#available_regions) (countries and + * territories) where the API is available. + */ +class UnsupportedUserLocationException(cause: Throwable? = null) : + GoogleGenerativeAIException("User location is not supported for the API use.", cause) + +/** + * Some form of state occurred that shouldn't have. + * + * Usually indicative of consumer error. + */ +class InvalidStateException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** + * A request was stopped during generation for some reason. + * + * @property response the full server response for the request + */ +class ResponseStoppedException(val response: GenerateContentResponse, cause: Throwable? = null) : + GoogleGenerativeAIException( + "Content generation stopped. Reason: ${response.candidates?.first()?.finishReason?.name}", + cause + ) + +/** + * A request took too long to complete. + * + * Usually occurs due to a user specified [timeout][RequestOptions.timeout]. + */ +class RequestTimeoutException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** Catch all case for exceptions not explicitly expected. */ +class UnknownException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt new file mode 100644 index 00000000..89afce30 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -0,0 +1,36 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import com.google.ai.client.generativeai.common.client.GenerationConfig +import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.SafetySetting +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +sealed interface Request + +@Serializable +data class GenerateContentRequest( + val model: String, + val contents: List, + @SerialName("safety_settings") val safetySettings: List? = null, + @SerialName("generation_config") val generationConfig: GenerationConfig? = null, +) : Request + +@Serializable +data class CountTokensRequest(val model: String, val contents: List) : Request diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt new file mode 100644 index 00000000..82ea8842 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt @@ -0,0 +1,40 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import io.ktor.client.plugins.HttpTimeout +import kotlin.time.Duration +import kotlin.time.DurationUnit +import kotlin.time.toDuration + +/** + * Configurable options unique to how requests to the backend are performed. + * + * @property timeout the maximum amount of time for a request to take, from the first request to + * first response. + * @property apiVersion the api endpoint to call. + */ +class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") { + @JvmOverloads + constructor( + timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS, + apiVersion: String = "v1" + ) : this( + (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS), + apiVersion + ) +} diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt new file mode 100644 index 00000000..c8801e37 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import com.google.ai.client.generativeai.common.server.Candidate +import com.google.ai.client.generativeai.common.server.GRpcError +import com.google.ai.client.generativeai.common.server.PromptFeedback +import kotlinx.serialization.Serializable + +sealed interface Response + +@Serializable +data class GenerateContentResponse( + val candidates: List? = null, + val promptFeedback: PromptFeedback? = null, +) : Response + +@Serializable data class CountTokensResponse(val totalTokens: Int) : Response + +@Serializable data class GRpcErrorResponse(val error: GRpcError) : Response diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt new file mode 100644 index 00000000..bacbb1f1 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.client + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class GenerationConfig( + val temperature: Float?, + @SerialName("top_p") val topP: Float?, + @SerialName("top_k") val topK: Int?, + @SerialName("candidate_count") val candidateCount: Int?, + @SerialName("max_output_tokens") val maxOutputTokens: Int?, + @SerialName("stop_sequences") val stopSequences: List? +) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt new file mode 100644 index 00000000..cd5a0d66 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt @@ -0,0 +1,104 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.server + +import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.HarmCategory +import com.google.ai.client.generativeai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonNames + +object BlockReasonSerializer : + KSerializer by FirstOrdinalSerializer(BlockReason::class) + +object HarmProbabilitySerializer : + KSerializer by FirstOrdinalSerializer(HarmProbability::class) + +object FinishReasonSerializer : + KSerializer by FirstOrdinalSerializer(FinishReason::class) + +@Serializable +data class PromptFeedback( + val blockReason: BlockReason? = null, + val safetyRatings: List? = null, +) + +@Serializable(BlockReasonSerializer::class) +enum class BlockReason { + UNKNOWN, + @SerialName("BLOCKED_REASON_UNSPECIFIED") UNSPECIFIED, + SAFETY, + OTHER +} + +@Serializable +data class Candidate( + val content: Content? = null, + val finishReason: FinishReason? = null, + val safetyRatings: List? = null, + val citationMetadata: CitationMetadata? = null +) + +@Serializable +data class CitationMetadata +@OptIn(ExperimentalSerializationApi::class) +constructor(@JsonNames("citations") val citationSources: List) + +@Serializable +data class CitationSources( + val startIndex: Int, + val endIndex: Int, + val uri: String, + val license: String +) + +@Serializable +data class SafetyRating( + val category: HarmCategory, + val probability: HarmProbability, + val blocked: Boolean? = null // TODO(): any reason not to default to false? +) + +@Serializable(HarmProbabilitySerializer::class) +enum class HarmProbability { + UNKNOWN, + @SerialName("HARM_PROBABILITY_UNSPECIFIED") UNSPECIFIED, + NEGLIGIBLE, + LOW, + MEDIUM, + HIGH +} + +@Serializable(FinishReasonSerializer::class) +enum class FinishReason { + UNKNOWN, + @SerialName("FINISH_REASON_UNSPECIFIED") UNSPECIFIED, + STOP, + MAX_TOKENS, + SAFETY, + RECITATION, + OTHER +} + +@Serializable +data class GRpcError( + val code: Int, + val message: String, +) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt new file mode 100644 index 00000000..8fc641be --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -0,0 +1,82 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.shared + +import com.google.ai.client.generativeai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.JsonContentPolymorphicSerializer +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.jsonObject + +object HarmCategorySerializer : + KSerializer by FirstOrdinalSerializer(HarmCategory::class) + +@Serializable(HarmCategorySerializer::class) +enum class HarmCategory { + UNKNOWN, + @SerialName("HARM_CATEGORY_HARASSMENT") HARASSMENT, + @SerialName("HARM_CATEGORY_HATE_SPEECH") HATE_SPEECH, + @SerialName("HARM_CATEGORY_SEXUALLY_EXPLICIT") SEXUALLY_EXPLICIT, + @SerialName("HARM_CATEGORY_DANGEROUS_CONTENT") DANGEROUS_CONTENT +} + +typealias Base64 = String + +@ExperimentalSerializationApi +@Serializable +data class Content(@EncodeDefault val role: String? = "user", val parts: List) + +@Serializable(PartSerializer::class) sealed interface Part + +@Serializable data class TextPart(val text: String) : Part + +@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part + +@Serializable +data class Blob( + @SerialName("mime_type") val mimeType: String, + val data: Base64, +) + +@Serializable +data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold) + +@Serializable +enum class HarmBlockThreshold { + @SerialName("HARM_BLOCK_THRESHOLD_UNSPECIFIED") UNSPECIFIED, + BLOCK_LOW_AND_ABOVE, + BLOCK_MEDIUM_AND_ABOVE, + BLOCK_ONLY_HIGH, + BLOCK_NONE, +} + +object PartSerializer : JsonContentPolymorphicSerializer(Part::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + return when { + "text" in jsonObject -> TextPart.serializer() + "inlineData" in jsonObject -> BlobPart.serializer() + else -> throw SerializationException("Unknown Part type") + } + } +} diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/kotlin.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/kotlin.kt new file mode 100644 index 00000000..8f681542 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/kotlin.kt @@ -0,0 +1,41 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.util + +import java.lang.reflect.Field + +/** + * Removes the last character from the [StringBuilder]. + * + * If the StringBuilder is empty, calling this function will throw an [IndexOutOfBoundsException]. + * + * @return The [StringBuilder] used to make the call, for optional chaining. + * @throws IndexOutOfBoundsException if the StringBuilder is empty. + */ +internal fun StringBuilder.removeLast(): StringBuilder = + if (isEmpty()) throw IndexOutOfBoundsException("StringBuilder is empty.") + else deleteCharAt(length - 1) + +/** + * A variant of [getAnnotation][Field.getAnnotation] that provides implicit Kotlin support. + * + * Syntax sugar for: + * ``` + * getAnnotation(T::class.java) + * ``` + */ +internal inline fun Field.getAnnotation() = getAnnotation(T::class.java) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/ktor.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/ktor.kt new file mode 100644 index 00000000..04eb18dc --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/ktor.kt @@ -0,0 +1,101 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@file:Suppress("DEPRECATION") // a replacement for our purposes has not been published yet + +package com.google.ai.client.generativeai.common.util + +import io.ktor.utils.io.ByteChannel +import io.ktor.utils.io.ByteReadChannel +import io.ktor.utils.io.close +import io.ktor.utils.io.readUTF8Line +import io.ktor.utils.io.writeFully +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.channelFlow +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.Json + +/** + * Suspends and processes each line read from the [ByteReadChannel] until the channel is closed for + * read. + * + * This extension function facilitates processing the stream of lines in a manner that takes into + * account EOF/empty strings- and avoids calling [block] as such. + * + * Example usage: + * ``` + * val channel: ByteReadChannel = ByteReadChannel("Hello, World!") + * channel.onEachLine { + * println("Received line: $it") + * } + * ``` + * + * @param block A suspending function to process each line. + */ +internal suspend fun ByteReadChannel.onEachLine(block: suspend (String) -> Unit) { + while (!isClosedForRead) { + awaitContent() + val line = readUTF8Line()?.takeUnless { it.isEmpty() } ?: continue + block(line) + } +} + +/** + * Decodes a stream of JSON elements from the given [ByteReadChannel] into a [Flow] of objects of + * type [T]. + * + * This function takes in a stream of events, each with a set of named parts. Parts are separated by + * an HTTP \r\n newline, events are separated by a double HTTP \r\n\r\n newline. This function + * assumes every event will only contain a named "data" part with a JSON object. Each data JSON is + * decoded into an instance of [T] and emitted as it is read from the channel. + * + * Example usage: + * ``` + * val json = Json { ignoreUnknownKeys = true } // Create a Json instance with any configurations + * val channel: ByteReadChannel = ByteReadChannel("data: {\"name\":\"Alice\"}\r\n\r\ndata: {\"name\":\"Bob\"}]") + * + * json.decodeToFlow(channel).collect { person -> + * println(person.name) + * } + * ``` + * + * @param T The type of objects to decode from the JSON stream. + * @param channel The [ByteReadChannel] from which the JSON stream will be read. + * @return A [Flow] of objects of type [T]. + * @throws SerializationException in case of any decoding-specific error + * @throws IllegalArgumentException if the decoded input is not a valid instance of [T] + */ +internal inline fun Json.decodeToFlow(channel: ByteReadChannel): Flow = channelFlow { + channel.onEachLine { + val data = it.removePrefix("data:") + send(decodeFromString(data)) + } +} + +/** + * Writes the provided [bytes] to the channel and closes it. + * + * Just a wrapper around [writeFully] that closes the channel after writing is complete. + * + * @param bytes the data to send through the channel + */ +internal suspend fun ByteChannel.send(bytes: ByteArray) { + writeFully(bytes) + close() +} + +/** String separator used in SSE communication to signal the end of a message. */ +internal const val SSE_SEPARATOR = "\r\n\r\n" diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/serialization.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/serialization.kt new file mode 100644 index 00000000..d818c266 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/serialization.kt @@ -0,0 +1,83 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.util + +import android.util.Log +import com.google.ai.client.generativeai.common.SerializationException +import kotlin.reflect.KClass +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.buildClassSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder + +/** + * Serializer for enums that defaults to the first ordinal on unknown types. + * + * Convention is that the first enum be named `UNKNOWN`, but any name is valid. + * + * When an unknown enum value is found, the enum itself will be logged to stderr with a message + * about opening an issue on GitHub regarding the new enum value. + */ +class FirstOrdinalSerializer>(private val enumClass: KClass) : KSerializer { + override val descriptor: SerialDescriptor = buildClassSerialDescriptor("FirstOrdinalSerializer") + + override fun deserialize(decoder: Decoder): T { + val name = decoder.decodeString() + val values = enumClass.enumValues() + + return values.firstOrNull { it.serialName == name } + ?: values.first().also { printWarning(name) } + } + + private fun printWarning(name: String) { + Log.e( + "FirstOrdinalSerializer", + """ + |Unknown enum value found: $name" + |This usually means the backend was updated, and the SDK needs to be updated to match it. + |Check if there's a new version for the SDK, otherwise please open an issue on our + |GitHub to bring it to our attention: + |https://github.com/google/google-ai-android + """ + .trimMargin() + ) + } + + override fun serialize(encoder: Encoder, value: T) { + encoder.encodeString(value.serialName) + } +} + +/** + * Provides the name to be used in serialization for this enum value. + * + * By default an enum is serialized to its [name][Enum.name], and can be overwritten by providing a + * [SerialName] annotation. + */ +val > T.serialName: String + get() = declaringJavaClass.getField(name).getAnnotation()?.value ?: name + +/** + * Variant of [kotlin.enumValues] that provides support for [KClass] instances of enums. + * + * @throws SerializationException if the class is not a valid enum. Beyond runtime emily magic, this + * shouldn't really be possible. + */ +fun > KClass.enumValues(): Array = + java.enumConstants ?: throw SerializationException("$simpleName is not a valid enum type.") diff --git a/settings.gradle.kts b/settings.gradle.kts index 44f363aa..3f4e53af 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -31,4 +31,5 @@ dependencyResolutionManagement { rootProject.name = "generativeai" include(":generativeai") +include(":common") includeBuild("./plugins")