diff --git a/common/.gitignore b/common/.gitignore new file mode 100644 index 00000000..796b96d1 --- /dev/null +++ b/common/.gitignore @@ -0,0 +1 @@ +/build diff --git a/common/build.gradle.kts b/common/build.gradle.kts new file mode 100644 index 00000000..bd02f5fa --- /dev/null +++ b/common/build.gradle.kts @@ -0,0 +1,126 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id("com.android.library") + id("maven-publish") + id("com.ncorti.ktfmt.gradle") + id("changelog-plugin") + id("release-plugin") + kotlin("android") + kotlin("plugin.serialization") +} + +ktfmt { + googleStyle() +} + +android { + namespace = "com.google.ai.client.generativeai.common" + compileSdk = 34 + + buildFeatures.buildConfig = true + + defaultConfig { + minSdk = 21 + + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + consumerProguardFiles("consumer-rules.pro") + + buildConfigField("String", "VERSION_NAME", "\"${project.version.toString()}\"") + } + + publishing { + singleVariant("release") { + withSourcesJar() + } + } + + buildTypes { + release { + isMinifyEnabled = false + proguardFiles( + getDefaultProguardFile("proguard-android-optimize.txt"), + "proguard-rules.pro" + ) + } + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + kotlinOptions { + jvmTarget = "17" + } + + testOptions { + unitTests.isReturnDefaultValues = true + } +} + +dependencies { + val ktorVersion = "2.3.2" + + implementation("io.ktor:ktor-client-okhttp:$ktorVersion") + implementation("io.ktor:ktor-client-core:$ktorVersion") + implementation("io.ktor:ktor-client-content-negotiation:$ktorVersion") + implementation("io.ktor:ktor-serialization-kotlinx-json:$ktorVersion") + implementation("io.ktor:ktor-client-logging:$ktorVersion") + + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1") + implementation("androidx.core:core-ktx:1.12.0") + implementation("org.slf4j:slf4j-nop:2.0.9") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3") + implementation("org.reactivestreams:reactive-streams:1.0.3") + + implementation("com.google.guava:listenablefuture:1.0") + implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha02") + implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha02") + testImplementation("junit:junit:4.13.2") + testImplementation("io.kotest:kotest-assertions-core:4.0.7") + testImplementation("io.kotest:kotest-assertions-jvm:4.0.7") + testImplementation("io.kotest:kotest-assertions-json:4.0.7") + testImplementation("io.ktor:ktor-client-mock:$ktorVersion") + androidTestImplementation("androidx.test.ext:junit:1.1.5") + androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") +} + +publishing { + publications { + register("release") { + groupId = "com.google.ai.client.generativeai" + artifactId = "common" + version = project.version.toString() + pom { + licenses { + license { + name = "The Apache License, Version 2.0" + url = "http://www.apache.org/licenses/LICENSE-2.0.txt" + } + } + } + afterEvaluate { + from(components["release"]) + } + } + } + repositories { + maven { + url = uri("${projectDir}/m2") + } + } +} diff --git a/common/consumer-rules.pro b/common/consumer-rules.pro new file mode 100644 index 00000000..f1b42451 --- /dev/null +++ b/common/consumer-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/common/gradle.properties b/common/gradle.properties new file mode 100644 index 00000000..de55ab65 --- /dev/null +++ b/common/gradle.properties @@ -0,0 +1 @@ +version=0.1.0 diff --git a/common/proguard-rules.pro b/common/proguard-rules.pro new file mode 100644 index 00000000..f1b42451 --- /dev/null +++ b/common/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/common/src/main/AndroidManifest.xml b/common/src/main/AndroidManifest.xml new file mode 100644 index 00000000..7ebb9c28 --- /dev/null +++ b/common/src/main/AndroidManifest.xml @@ -0,0 +1,18 @@ + + + + + diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt new file mode 100644 index 00000000..57e5413a --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt @@ -0,0 +1,185 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import com.google.ai.client.generativeai.common.util.decodeToFlow +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.engine.HttpClientEngine +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.preparePost +import io.ktor.client.request.setBody +import io.ktor.client.statement.HttpResponse +import io.ktor.client.statement.bodyAsChannel +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.serialization.kotlinx.json.json +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.timeout +import kotlinx.coroutines.launch +import kotlinx.serialization.json.Json + +const val DOMAIN = "https://generativelanguage.googleapis.com" + +val JSON = Json { + ignoreUnknownKeys = true + prettyPrint = false +} + +/** + * Backend class for interfacing with the Gemini API. + * + * This class handles making HTTP requests to the API and streaming the responses back. + * + * @param httpEngine The HTTP client engine to be used for making requests. Defaults to CIO engine. + * Exposed primarily for DI in tests. + * @property key The API key used for authentication. + * @property model The model to use for generation. + * @property apiVersion the endpoint version to communicate with. + * @property timeout the maximum amount of time for a request to take in the initial exchange. + */ +class APIController( + private val key: String, + model: String, + private val requestOptions: RequestOptions, + httpEngine: HttpClientEngine = OkHttp.create(), +) { + private val model = fullModelName(model) + + private val client = + HttpClient(httpEngine) { + install(HttpTimeout) { + requestTimeoutMillis = requestOptions.timeout.inWholeMilliseconds + socketTimeoutMillis = 80_000 + } + install(ContentNegotiation) { json(JSON) } + } + + suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse = + client + .post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") { + applyCommonConfiguration(request) + } + .also { validateResponse(it) } + .body() + + fun generateContentStream(request: GenerateContentRequest): Flow { + return client.postStream( + "$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" + ) { + applyCommonConfiguration(request) + } + } + + suspend fun countTokens(request: CountTokensRequest): CountTokensResponse = + client + .post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") { + applyCommonConfiguration(request) + } + .also { validateResponse(it) } + .body() + + private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) { + when (request) { + is GenerateContentRequest -> setBody(request) + is CountTokensRequest -> setBody(request) + } + contentType(ContentType.Application.Json) + header("x-goog-api-key", key) + header("x-goog-api-client", "genai-android/${BuildConfig.VERSION_NAME}") + } +} + +/** + * Ensures the model name provided has a `models/` prefix + * + * Models must be prepended with the `models/` prefix when communicating with the backend. + */ +private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name" + +/** + * Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects + * of type [R]. The response is expected to be a stream of JSON objects that are parsed in real-time + * as they are received from the server. + * + * This function is intended for internal use within the client that handles streaming responses. + * + * Example usage: + * ``` + * val client: HttpClient = HttpClient(CIO) + * val request: Request = GenerateContentRequest(...) + * val url: String = "http://example.com/stream" + * + * val responses: GenerateContentResponse = client.postStream(url) { + * setBody(request) + * contentType(ContentType.Application.Json) + * } + * responses.collect { + * println("Got a response: $it") + * } + * ``` + * + * @param R The type of the response object. + * @param url The URL to which the POST request will be made. + * @param config An optional [HttpRequestBuilder] callback for request configuration. + * @return A [Flow] of response objects of type [R]. + */ +private inline fun HttpClient.postStream( + url: String, + crossinline config: HttpRequestBuilder.() -> Unit = {} +): Flow = channelFlow { + launch(CoroutineName("postStream")) { + preparePost(url) { config() } + .execute { + validateResponse(it) + + val channel = it.bodyAsChannel() + val flow = JSON.decodeToFlow(channel) + + flow.collect { send(it) } + } + } +} + +private suspend fun validateResponse(response: HttpResponse) { + if (response.status != HttpStatusCode.OK) { + val text = response.bodyAsText() + val message = + try { + JSON.decodeFromString(text).error.message + } catch (e: Throwable) { + "Unexpected Response:\n$text" + } + if (message.contains("API key not valid")) { + throw InvalidAPIKeyException(message) + } + // TODO (b/325117891): Use a better method than string matching. + if (message == "User location is not supported for the API use.") { + throw UnsupportedUserLocationException() + } + throw ServerException(message) + } +} diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt new file mode 100644 index 00000000..d28ea2ff --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt @@ -0,0 +1,113 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import io.ktor.serialization.JsonConvertException +import kotlinx.coroutines.TimeoutCancellationException + +/** Parent class for any errors that occur. */ +sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = null) : + RuntimeException(message, cause) { + companion object { + + /** + * Converts a [Throwable] to a [GoogleGenerativeAIException]. + * + * Will populate default messages as expected, and propagate the provided [cause] through the + * resulting exception. + */ + fun from(cause: Throwable): GoogleGenerativeAIException = + when (cause) { + is GoogleGenerativeAIException -> cause + is JsonConvertException, + is kotlinx.serialization.SerializationException -> + SerializationException( + "Something went wrong while trying to deserialize a response from the server.", + cause + ) + is TimeoutCancellationException -> + RequestTimeoutException("The request failed to complete in the allotted time.") + else -> UnknownException("Something unexpected happened.", cause) + } + } +} + +/** Something went wrong while trying to deserialize a response from the server. */ +class SerializationException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** The server responded with a non 200 response code. */ +class ServerException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** The server responded that the API Key is no valid. */ +class InvalidAPIKeyException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** + * A request was blocked for some reason. + * + * See the [response's][response] `promptFeedback.blockReason` for more information. + * + * @property response the full server response for the request. + */ +class PromptBlockedException(val response: GenerateContentResponse, cause: Throwable? = null) : + GoogleGenerativeAIException( + "Prompt was blocked: ${response.promptFeedback?.blockReason?.name}", + cause + ) + +/** + * The user's location (region) is not supported by the API. + * + * See the Google documentation for a + * [list of regions](https://ai.google.dev/available_regions#available_regions) (countries and + * territories) where the API is available. + */ +class UnsupportedUserLocationException(cause: Throwable? = null) : + GoogleGenerativeAIException("User location is not supported for the API use.", cause) + +/** + * Some form of state occurred that shouldn't have. + * + * Usually indicative of consumer error. + */ +class InvalidStateException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** + * A request was stopped during generation for some reason. + * + * @property response the full server response for the request + */ +class ResponseStoppedException(val response: GenerateContentResponse, cause: Throwable? = null) : + GoogleGenerativeAIException( + "Content generation stopped. Reason: ${response.candidates?.first()?.finishReason?.name}", + cause + ) + +/** + * A request took too long to complete. + * + * Usually occurs due to a user specified [timeout][RequestOptions.timeout]. + */ +class RequestTimeoutException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** Catch all case for exceptions not explicitly expected. */ +class UnknownException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt new file mode 100644 index 00000000..89afce30 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -0,0 +1,36 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import com.google.ai.client.generativeai.common.client.GenerationConfig +import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.SafetySetting +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +sealed interface Request + +@Serializable +data class GenerateContentRequest( + val model: String, + val contents: List, + @SerialName("safety_settings") val safetySettings: List? = null, + @SerialName("generation_config") val generationConfig: GenerationConfig? = null, +) : Request + +@Serializable +data class CountTokensRequest(val model: String, val contents: List) : Request diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt new file mode 100644 index 00000000..82ea8842 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt @@ -0,0 +1,40 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import io.ktor.client.plugins.HttpTimeout +import kotlin.time.Duration +import kotlin.time.DurationUnit +import kotlin.time.toDuration + +/** + * Configurable options unique to how requests to the backend are performed. + * + * @property timeout the maximum amount of time for a request to take, from the first request to + * first response. + * @property apiVersion the api endpoint to call. + */ +class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") { + @JvmOverloads + constructor( + timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS, + apiVersion: String = "v1" + ) : this( + (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS), + apiVersion + ) +} diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt new file mode 100644 index 00000000..c8801e37 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import com.google.ai.client.generativeai.common.server.Candidate +import com.google.ai.client.generativeai.common.server.GRpcError +import com.google.ai.client.generativeai.common.server.PromptFeedback +import kotlinx.serialization.Serializable + +sealed interface Response + +@Serializable +data class GenerateContentResponse( + val candidates: List? = null, + val promptFeedback: PromptFeedback? = null, +) : Response + +@Serializable data class CountTokensResponse(val totalTokens: Int) : Response + +@Serializable data class GRpcErrorResponse(val error: GRpcError) : Response diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt new file mode 100644 index 00000000..bacbb1f1 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.client + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class GenerationConfig( + val temperature: Float?, + @SerialName("top_p") val topP: Float?, + @SerialName("top_k") val topK: Int?, + @SerialName("candidate_count") val candidateCount: Int?, + @SerialName("max_output_tokens") val maxOutputTokens: Int?, + @SerialName("stop_sequences") val stopSequences: List? +) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt new file mode 100644 index 00000000..cd5a0d66 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt @@ -0,0 +1,104 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.server + +import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.HarmCategory +import com.google.ai.client.generativeai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonNames + +object BlockReasonSerializer : + KSerializer by FirstOrdinalSerializer(BlockReason::class) + +object HarmProbabilitySerializer : + KSerializer by FirstOrdinalSerializer(HarmProbability::class) + +object FinishReasonSerializer : + KSerializer by FirstOrdinalSerializer(FinishReason::class) + +@Serializable +data class PromptFeedback( + val blockReason: BlockReason? = null, + val safetyRatings: List? = null, +) + +@Serializable(BlockReasonSerializer::class) +enum class BlockReason { + UNKNOWN, + @SerialName("BLOCKED_REASON_UNSPECIFIED") UNSPECIFIED, + SAFETY, + OTHER +} + +@Serializable +data class Candidate( + val content: Content? = null, + val finishReason: FinishReason? = null, + val safetyRatings: List? = null, + val citationMetadata: CitationMetadata? = null +) + +@Serializable +data class CitationMetadata +@OptIn(ExperimentalSerializationApi::class) +constructor(@JsonNames("citations") val citationSources: List) + +@Serializable +data class CitationSources( + val startIndex: Int, + val endIndex: Int, + val uri: String, + val license: String +) + +@Serializable +data class SafetyRating( + val category: HarmCategory, + val probability: HarmProbability, + val blocked: Boolean? = null // TODO(): any reason not to default to false? +) + +@Serializable(HarmProbabilitySerializer::class) +enum class HarmProbability { + UNKNOWN, + @SerialName("HARM_PROBABILITY_UNSPECIFIED") UNSPECIFIED, + NEGLIGIBLE, + LOW, + MEDIUM, + HIGH +} + +@Serializable(FinishReasonSerializer::class) +enum class FinishReason { + UNKNOWN, + @SerialName("FINISH_REASON_UNSPECIFIED") UNSPECIFIED, + STOP, + MAX_TOKENS, + SAFETY, + RECITATION, + OTHER +} + +@Serializable +data class GRpcError( + val code: Int, + val message: String, +) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt new file mode 100644 index 00000000..8fc641be --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -0,0 +1,82 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.shared + +import com.google.ai.client.generativeai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.JsonContentPolymorphicSerializer +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.jsonObject + +object HarmCategorySerializer : + KSerializer by FirstOrdinalSerializer(HarmCategory::class) + +@Serializable(HarmCategorySerializer::class) +enum class HarmCategory { + UNKNOWN, + @SerialName("HARM_CATEGORY_HARASSMENT") HARASSMENT, + @SerialName("HARM_CATEGORY_HATE_SPEECH") HATE_SPEECH, + @SerialName("HARM_CATEGORY_SEXUALLY_EXPLICIT") SEXUALLY_EXPLICIT, + @SerialName("HARM_CATEGORY_DANGEROUS_CONTENT") DANGEROUS_CONTENT +} + +typealias Base64 = String + +@ExperimentalSerializationApi +@Serializable +data class Content(@EncodeDefault val role: String? = "user", val parts: List) + +@Serializable(PartSerializer::class) sealed interface Part + +@Serializable data class TextPart(val text: String) : Part + +@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part + +@Serializable +data class Blob( + @SerialName("mime_type") val mimeType: String, + val data: Base64, +) + +@Serializable +data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold) + +@Serializable +enum class HarmBlockThreshold { + @SerialName("HARM_BLOCK_THRESHOLD_UNSPECIFIED") UNSPECIFIED, + BLOCK_LOW_AND_ABOVE, + BLOCK_MEDIUM_AND_ABOVE, + BLOCK_ONLY_HIGH, + BLOCK_NONE, +} + +object PartSerializer : JsonContentPolymorphicSerializer(Part::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + return when { + "text" in jsonObject -> TextPart.serializer() + "inlineData" in jsonObject -> BlobPart.serializer() + else -> throw SerializationException("Unknown Part type") + } + } +} diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/kotlin.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/kotlin.kt new file mode 100644 index 00000000..8f681542 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/kotlin.kt @@ -0,0 +1,41 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.util + +import java.lang.reflect.Field + +/** + * Removes the last character from the [StringBuilder]. + * + * If the StringBuilder is empty, calling this function will throw an [IndexOutOfBoundsException]. + * + * @return The [StringBuilder] used to make the call, for optional chaining. + * @throws IndexOutOfBoundsException if the StringBuilder is empty. + */ +internal fun StringBuilder.removeLast(): StringBuilder = + if (isEmpty()) throw IndexOutOfBoundsException("StringBuilder is empty.") + else deleteCharAt(length - 1) + +/** + * A variant of [getAnnotation][Field.getAnnotation] that provides implicit Kotlin support. + * + * Syntax sugar for: + * ``` + * getAnnotation(T::class.java) + * ``` + */ +internal inline fun Field.getAnnotation() = getAnnotation(T::class.java) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/ktor.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/ktor.kt new file mode 100644 index 00000000..04eb18dc --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/ktor.kt @@ -0,0 +1,101 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@file:Suppress("DEPRECATION") // a replacement for our purposes has not been published yet + +package com.google.ai.client.generativeai.common.util + +import io.ktor.utils.io.ByteChannel +import io.ktor.utils.io.ByteReadChannel +import io.ktor.utils.io.close +import io.ktor.utils.io.readUTF8Line +import io.ktor.utils.io.writeFully +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.channelFlow +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.Json + +/** + * Suspends and processes each line read from the [ByteReadChannel] until the channel is closed for + * read. + * + * This extension function facilitates processing the stream of lines in a manner that takes into + * account EOF/empty strings- and avoids calling [block] as such. + * + * Example usage: + * ``` + * val channel: ByteReadChannel = ByteReadChannel("Hello, World!") + * channel.onEachLine { + * println("Received line: $it") + * } + * ``` + * + * @param block A suspending function to process each line. + */ +internal suspend fun ByteReadChannel.onEachLine(block: suspend (String) -> Unit) { + while (!isClosedForRead) { + awaitContent() + val line = readUTF8Line()?.takeUnless { it.isEmpty() } ?: continue + block(line) + } +} + +/** + * Decodes a stream of JSON elements from the given [ByteReadChannel] into a [Flow] of objects of + * type [T]. + * + * This function takes in a stream of events, each with a set of named parts. Parts are separated by + * an HTTP \r\n newline, events are separated by a double HTTP \r\n\r\n newline. This function + * assumes every event will only contain a named "data" part with a JSON object. Each data JSON is + * decoded into an instance of [T] and emitted as it is read from the channel. + * + * Example usage: + * ``` + * val json = Json { ignoreUnknownKeys = true } // Create a Json instance with any configurations + * val channel: ByteReadChannel = ByteReadChannel("data: {\"name\":\"Alice\"}\r\n\r\ndata: {\"name\":\"Bob\"}]") + * + * json.decodeToFlow(channel).collect { person -> + * println(person.name) + * } + * ``` + * + * @param T The type of objects to decode from the JSON stream. + * @param channel The [ByteReadChannel] from which the JSON stream will be read. + * @return A [Flow] of objects of type [T]. + * @throws SerializationException in case of any decoding-specific error + * @throws IllegalArgumentException if the decoded input is not a valid instance of [T] + */ +internal inline fun Json.decodeToFlow(channel: ByteReadChannel): Flow = channelFlow { + channel.onEachLine { + val data = it.removePrefix("data:") + send(decodeFromString(data)) + } +} + +/** + * Writes the provided [bytes] to the channel and closes it. + * + * Just a wrapper around [writeFully] that closes the channel after writing is complete. + * + * @param bytes the data to send through the channel + */ +internal suspend fun ByteChannel.send(bytes: ByteArray) { + writeFully(bytes) + close() +} + +/** String separator used in SSE communication to signal the end of a message. */ +internal const val SSE_SEPARATOR = "\r\n\r\n" diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/serialization.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/serialization.kt new file mode 100644 index 00000000..d818c266 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/serialization.kt @@ -0,0 +1,83 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.util + +import android.util.Log +import com.google.ai.client.generativeai.common.SerializationException +import kotlin.reflect.KClass +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.buildClassSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder + +/** + * Serializer for enums that defaults to the first ordinal on unknown types. + * + * Convention is that the first enum be named `UNKNOWN`, but any name is valid. + * + * When an unknown enum value is found, the enum itself will be logged to stderr with a message + * about opening an issue on GitHub regarding the new enum value. + */ +class FirstOrdinalSerializer>(private val enumClass: KClass) : KSerializer { + override val descriptor: SerialDescriptor = buildClassSerialDescriptor("FirstOrdinalSerializer") + + override fun deserialize(decoder: Decoder): T { + val name = decoder.decodeString() + val values = enumClass.enumValues() + + return values.firstOrNull { it.serialName == name } + ?: values.first().also { printWarning(name) } + } + + private fun printWarning(name: String) { + Log.e( + "FirstOrdinalSerializer", + """ + |Unknown enum value found: $name" + |This usually means the backend was updated, and the SDK needs to be updated to match it. + |Check if there's a new version for the SDK, otherwise please open an issue on our + |GitHub to bring it to our attention: + |https://github.com/google/google-ai-android + """ + .trimMargin() + ) + } + + override fun serialize(encoder: Encoder, value: T) { + encoder.encodeString(value.serialName) + } +} + +/** + * Provides the name to be used in serialization for this enum value. + * + * By default an enum is serialized to its [name][Enum.name], and can be overwritten by providing a + * [SerialName] annotation. + */ +val > T.serialName: String + get() = declaringJavaClass.getField(name).getAnnotation()?.value ?: name + +/** + * Variant of [kotlin.enumValues] that provides support for [KClass] instances of enums. + * + * @throws SerializationException if the class is not a valid enum. Beyond runtime emily magic, this + * shouldn't really be possible. + */ +fun > KClass.enumValues(): Array = + java.enumConstants ?: throw SerializationException("$simpleName is not a valid enum type.") diff --git a/settings.gradle.kts b/settings.gradle.kts index 44f363aa..3f4e53af 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -31,4 +31,5 @@ dependencyResolutionManagement { rootProject.name = "generativeai" include(":generativeai") +include(":common") includeBuild("./plugins")