diff --git a/common/.gitignore b/common/.gitignore new file mode 100644 index 00000000..796b96d1 --- /dev/null +++ b/common/.gitignore @@ -0,0 +1 @@ +/build diff --git a/common/build.gradle.kts b/common/build.gradle.kts new file mode 100644 index 00000000..bd02f5fa --- /dev/null +++ b/common/build.gradle.kts @@ -0,0 +1,126 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id("com.android.library") + id("maven-publish") + id("com.ncorti.ktfmt.gradle") + id("changelog-plugin") + id("release-plugin") + kotlin("android") + kotlin("plugin.serialization") +} + +ktfmt { + googleStyle() +} + +android { + namespace = "com.google.ai.client.generativeai.common" + compileSdk = 34 + + buildFeatures.buildConfig = true + + defaultConfig { + minSdk = 21 + + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + consumerProguardFiles("consumer-rules.pro") + + buildConfigField("String", "VERSION_NAME", "\"${project.version.toString()}\"") + } + + publishing { + singleVariant("release") { + withSourcesJar() + } + } + + buildTypes { + release { + isMinifyEnabled = false + proguardFiles( + getDefaultProguardFile("proguard-android-optimize.txt"), + "proguard-rules.pro" + ) + } + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + kotlinOptions { + jvmTarget = "17" + } + + testOptions { + unitTests.isReturnDefaultValues = true + } +} + +dependencies { + val ktorVersion = "2.3.2" + + implementation("io.ktor:ktor-client-okhttp:$ktorVersion") + implementation("io.ktor:ktor-client-core:$ktorVersion") + implementation("io.ktor:ktor-client-content-negotiation:$ktorVersion") + implementation("io.ktor:ktor-serialization-kotlinx-json:$ktorVersion") + implementation("io.ktor:ktor-client-logging:$ktorVersion") + + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1") + implementation("androidx.core:core-ktx:1.12.0") + implementation("org.slf4j:slf4j-nop:2.0.9") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3") + implementation("org.reactivestreams:reactive-streams:1.0.3") + + implementation("com.google.guava:listenablefuture:1.0") + implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha02") + implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha02") + testImplementation("junit:junit:4.13.2") + testImplementation("io.kotest:kotest-assertions-core:4.0.7") + testImplementation("io.kotest:kotest-assertions-jvm:4.0.7") + testImplementation("io.kotest:kotest-assertions-json:4.0.7") + testImplementation("io.ktor:ktor-client-mock:$ktorVersion") + androidTestImplementation("androidx.test.ext:junit:1.1.5") + androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") +} + +publishing { + publications { + register("release") { + groupId = "com.google.ai.client.generativeai" + artifactId = "common" + version = project.version.toString() + pom { + licenses { + license { + name = "The Apache License, Version 2.0" + url = "http://www.apache.org/licenses/LICENSE-2.0.txt" + } + } + } + afterEvaluate { + from(components["release"]) + } + } + } + repositories { + maven { + url = uri("${projectDir}/m2") + } + } +} diff --git a/common/consumer-rules.pro b/common/consumer-rules.pro new file mode 100644 index 00000000..f1b42451 --- /dev/null +++ b/common/consumer-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/common/gradle.properties b/common/gradle.properties new file mode 100644 index 00000000..de55ab65 --- /dev/null +++ b/common/gradle.properties @@ -0,0 +1 @@ +version=0.1.0 diff --git a/common/proguard-rules.pro b/common/proguard-rules.pro new file mode 100644 index 00000000..f1b42451 --- /dev/null +++ b/common/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/common/src/main/AndroidManifest.xml b/common/src/main/AndroidManifest.xml new file mode 100644 index 00000000..7ebb9c28 --- /dev/null +++ b/common/src/main/AndroidManifest.xml @@ -0,0 +1,18 @@ + + + + + diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt similarity index 66% rename from generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt rename to common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt index dc3c343f..13cfccc2 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,10 @@ * limitations under the License. */ -package com.google.ai.client.generativeai.internal.api +package com.google.ai.client.generativeai.common -import com.google.ai.client.generativeai.BuildConfig -import com.google.ai.client.generativeai.internal.util.decodeToFlow -import com.google.ai.client.generativeai.type.InvalidAPIKeyException -import com.google.ai.client.generativeai.type.RequestOptions -import com.google.ai.client.generativeai.type.ServerException -import com.google.ai.client.generativeai.type.UnsupportedUserLocationException +import com.google.ai.client.generativeai.common.server.FinishReason +import com.google.ai.client.generativeai.common.util.decodeToFlow import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.engine.HttpClientEngine @@ -42,14 +38,16 @@ import io.ktor.http.contentType import io.ktor.serialization.kotlinx.json.json import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.timeout import kotlinx.coroutines.launch import kotlinx.serialization.json.Json -internal const val DOMAIN = "https://generativelanguage.googleapis.com" +const val DOMAIN = "https://generativelanguage.googleapis.com" -internal val JSON = Json { +val JSON = Json { ignoreUnknownKeys = true prettyPrint = false } @@ -63,15 +61,21 @@ internal val JSON = Json { * Exposed primarily for DI in tests. * @property key The API key used for authentication. * @property model The model to use for generation. - * @property apiVersion the endpoint version to communicate with. - * @property timeout the maximum amount of time for a request to take in the initial exchange. */ -internal class APIController( +class APIController +internal constructor( private val key: String, model: String, private val requestOptions: RequestOptions, - httpEngine: HttpClientEngine = OkHttp.create(), + httpEngine: HttpClientEngine ) { + + constructor( + key: String, + model: String, + requestOptions: RequestOptions + ) : this(key, model, requestOptions, OkHttp.create()) + private val model = fullModelName(model) private val client = @@ -84,28 +88,39 @@ internal class APIController( } suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse = - client - .post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") { - applyCommonConfiguration(request) - } - .also { validateResponse(it) } - .body() - - fun generateContentStream(request: GenerateContentRequest): Flow { - return client.postStream( - "$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" - ) { - applyCommonConfiguration(request) + try { + client + .post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") { + applyCommonConfiguration(request) + } + .also { validateResponse(it) } + .body() + .validate() + } catch (e: Throwable) { + throw GoogleGenerativeAIException.from(e) } - } - suspend fun countTokens(request: CountTokensRequest): CountTokensResponse = + fun generateContentStream(request: GenerateContentRequest): Flow = client - .post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") { + .postStream( + "$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" + ) { applyCommonConfiguration(request) } - .also { validateResponse(it) } - .body() + .map { it.validate() } + .catch { throw GoogleGenerativeAIException.from(it) } + + suspend fun countTokens(request: CountTokensRequest): CountTokensResponse = + try { + client + .post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") { + applyCommonConfiguration(request) + } + .also { validateResponse(it) } + .body() + } catch (e: Throwable) { + throw GoogleGenerativeAIException.from(e) + } private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) { when (request) { @@ -170,21 +185,31 @@ private inline fun HttpClient.postStream( } private suspend fun validateResponse(response: HttpResponse) { - if (response.status != HttpStatusCode.OK) { - val text = response.bodyAsText() - val message = - try { - JSON.decodeFromString(text).error.message - } catch (e: Throwable) { - "Unexpected Response:\n$text" - } - if (message.contains("API key not valid")) { - throw InvalidAPIKeyException(message) - } - // TODO (b/325117891): Use a better method than string matching. - if (message == "User location is not supported for the API use.") { - throw UnsupportedUserLocationException() + if (response.status == HttpStatusCode.OK) return + val text = response.bodyAsText() + val message = + try { + JSON.decodeFromString(text).error.message + } catch (e: Throwable) { + "Unexpected Response:\n$text" } - throw ServerException(message) + if (message.contains("API key not valid")) { + throw InvalidAPIKeyException(message) + } + // TODO (b/325117891): Use a better method than string matching. + if (message == "User location is not supported for the API use.") { + throw UnsupportedUserLocationException() + } + throw ServerException(message) +} + +private fun GenerateContentResponse.validate() = apply { + if ((candidates?.isEmpty() != false) && promptFeedback == null) { + throw SerializationException("Error deserializing response, found no valid fields") } + promptFeedback?.blockReason?.let { throw PromptBlockedException(this) } + candidates + ?.mapNotNull { it.finishReason } + ?.firstOrNull { it != FinishReason.STOP } + ?.let { throw ResponseStoppedException(this) } } diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt new file mode 100644 index 00000000..d28ea2ff --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Exceptions.kt @@ -0,0 +1,113 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import io.ktor.serialization.JsonConvertException +import kotlinx.coroutines.TimeoutCancellationException + +/** Parent class for any errors that occur. */ +sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = null) : + RuntimeException(message, cause) { + companion object { + + /** + * Converts a [Throwable] to a [GoogleGenerativeAIException]. + * + * Will populate default messages as expected, and propagate the provided [cause] through the + * resulting exception. + */ + fun from(cause: Throwable): GoogleGenerativeAIException = + when (cause) { + is GoogleGenerativeAIException -> cause + is JsonConvertException, + is kotlinx.serialization.SerializationException -> + SerializationException( + "Something went wrong while trying to deserialize a response from the server.", + cause + ) + is TimeoutCancellationException -> + RequestTimeoutException("The request failed to complete in the allotted time.") + else -> UnknownException("Something unexpected happened.", cause) + } + } +} + +/** Something went wrong while trying to deserialize a response from the server. */ +class SerializationException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** The server responded with a non 200 response code. */ +class ServerException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** The server responded that the API Key is no valid. */ +class InvalidAPIKeyException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** + * A request was blocked for some reason. + * + * See the [response's][response] `promptFeedback.blockReason` for more information. + * + * @property response the full server response for the request. + */ +class PromptBlockedException(val response: GenerateContentResponse, cause: Throwable? = null) : + GoogleGenerativeAIException( + "Prompt was blocked: ${response.promptFeedback?.blockReason?.name}", + cause + ) + +/** + * The user's location (region) is not supported by the API. + * + * See the Google documentation for a + * [list of regions](https://ai.google.dev/available_regions#available_regions) (countries and + * territories) where the API is available. + */ +class UnsupportedUserLocationException(cause: Throwable? = null) : + GoogleGenerativeAIException("User location is not supported for the API use.", cause) + +/** + * Some form of state occurred that shouldn't have. + * + * Usually indicative of consumer error. + */ +class InvalidStateException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** + * A request was stopped during generation for some reason. + * + * @property response the full server response for the request + */ +class ResponseStoppedException(val response: GenerateContentResponse, cause: Throwable? = null) : + GoogleGenerativeAIException( + "Content generation stopped. Reason: ${response.candidates?.first()?.finishReason?.name}", + cause + ) + +/** + * A request took too long to complete. + * + * Usually occurs due to a user specified [timeout][RequestOptions.timeout]. + */ +class RequestTimeoutException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) + +/** Catch all case for exceptions not explicitly expected. */ +class UnknownException(message: String, cause: Throwable? = null) : + GoogleGenerativeAIException(message, cause) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt similarity index 65% rename from generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/Request.kt rename to common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index 44cbe85d..89afce30 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,18 +14,18 @@ * limitations under the License. */ -package com.google.ai.client.generativeai.internal.api +package com.google.ai.client.generativeai.common -import com.google.ai.client.generativeai.internal.api.client.GenerationConfig -import com.google.ai.client.generativeai.internal.api.shared.Content -import com.google.ai.client.generativeai.internal.api.shared.SafetySetting +import com.google.ai.client.generativeai.common.client.GenerationConfig +import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.SafetySetting import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable -internal sealed interface Request +sealed interface Request @Serializable -internal data class GenerateContentRequest( +data class GenerateContentRequest( val model: String, val contents: List, @SerialName("safety_settings") val safetySettings: List? = null, @@ -33,4 +33,4 @@ internal data class GenerateContentRequest( ) : Request @Serializable -internal data class CountTokensRequest(val model: String, val contents: List) : Request +data class CountTokensRequest(val model: String, val contents: List) : Request diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt new file mode 100644 index 00000000..82ea8842 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt @@ -0,0 +1,40 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import io.ktor.client.plugins.HttpTimeout +import kotlin.time.Duration +import kotlin.time.DurationUnit +import kotlin.time.toDuration + +/** + * Configurable options unique to how requests to the backend are performed. + * + * @property timeout the maximum amount of time for a request to take, from the first request to + * first response. + * @property apiVersion the api endpoint to call. + */ +class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") { + @JvmOverloads + constructor( + timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS, + apiVersion: String = "v1" + ) : this( + (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS), + apiVersion + ) +} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/Response.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt similarity index 56% rename from generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/Response.kt rename to common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt index df142c4f..c8801e37 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/Response.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Response.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,21 +14,21 @@ * limitations under the License. */ -package com.google.ai.client.generativeai.internal.api +package com.google.ai.client.generativeai.common -import com.google.ai.client.generativeai.internal.api.server.Candidate -import com.google.ai.client.generativeai.internal.api.server.GRpcError -import com.google.ai.client.generativeai.internal.api.server.PromptFeedback +import com.google.ai.client.generativeai.common.server.Candidate +import com.google.ai.client.generativeai.common.server.GRpcError +import com.google.ai.client.generativeai.common.server.PromptFeedback import kotlinx.serialization.Serializable -internal sealed interface Response +sealed interface Response @Serializable -internal data class GenerateContentResponse( +data class GenerateContentResponse( val candidates: List? = null, val promptFeedback: PromptFeedback? = null, ) : Response -@Serializable internal data class CountTokensResponse(val totalTokens: Int) : Response +@Serializable data class CountTokensResponse(val totalTokens: Int) : Response -@Serializable internal data class GRpcErrorResponse(val error: GRpcError) : Response +@Serializable data class GRpcErrorResponse(val error: GRpcError) : Response diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt similarity index 88% rename from generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/client/Types.kt rename to common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index f1f778bf..bacbb1f1 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,13 @@ * limitations under the License. */ -package com.google.ai.client.generativeai.internal.api.client +package com.google.ai.client.generativeai.common.client import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @Serializable -internal data class GenerationConfig( +data class GenerationConfig( val temperature: Float?, @SerialName("top_p") val topP: Float?, @SerialName("top_k") val topK: Int?, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/server/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt similarity index 74% rename from generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/server/Types.kt rename to common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt index 21b8bd30..cd5a0d66 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/server/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/server/Types.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,34 +14,34 @@ * limitations under the License. */ -package com.google.ai.client.generativeai.internal.api.server +package com.google.ai.client.generativeai.common.server -import com.google.ai.client.generativeai.internal.api.shared.Content -import com.google.ai.client.generativeai.internal.api.shared.HarmCategory -import com.google.ai.client.generativeai.internal.util.FirstOrdinalSerializer +import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.HarmCategory +import com.google.ai.client.generativeai.common.util.FirstOrdinalSerializer import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.KSerializer import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonNames -internal object BlockReasonSerializer : +object BlockReasonSerializer : KSerializer by FirstOrdinalSerializer(BlockReason::class) -internal object HarmProbabilitySerializer : +object HarmProbabilitySerializer : KSerializer by FirstOrdinalSerializer(HarmProbability::class) -internal object FinishReasonSerializer : +object FinishReasonSerializer : KSerializer by FirstOrdinalSerializer(FinishReason::class) @Serializable -internal data class PromptFeedback( +data class PromptFeedback( val blockReason: BlockReason? = null, val safetyRatings: List? = null, ) @Serializable(BlockReasonSerializer::class) -internal enum class BlockReason { +enum class BlockReason { UNKNOWN, @SerialName("BLOCKED_REASON_UNSPECIFIED") UNSPECIFIED, SAFETY, @@ -49,7 +49,7 @@ internal enum class BlockReason { } @Serializable -internal data class Candidate( +data class Candidate( val content: Content? = null, val finishReason: FinishReason? = null, val safetyRatings: List? = null, @@ -57,12 +57,12 @@ internal data class Candidate( ) @Serializable -internal data class CitationMetadata +data class CitationMetadata @OptIn(ExperimentalSerializationApi::class) constructor(@JsonNames("citations") val citationSources: List) @Serializable -internal data class CitationSources( +data class CitationSources( val startIndex: Int, val endIndex: Int, val uri: String, @@ -70,14 +70,14 @@ internal data class CitationSources( ) @Serializable -internal data class SafetyRating( +data class SafetyRating( val category: HarmCategory, val probability: HarmProbability, val blocked: Boolean? = null // TODO(): any reason not to default to false? ) @Serializable(HarmProbabilitySerializer::class) -internal enum class HarmProbability { +enum class HarmProbability { UNKNOWN, @SerialName("HARM_PROBABILITY_UNSPECIFIED") UNSPECIFIED, NEGLIGIBLE, @@ -87,7 +87,7 @@ internal enum class HarmProbability { } @Serializable(FinishReasonSerializer::class) -internal enum class FinishReason { +enum class FinishReason { UNKNOWN, @SerialName("FINISH_REASON_UNSPECIFIED") UNSPECIFIED, STOP, @@ -98,7 +98,7 @@ internal enum class FinishReason { } @Serializable -internal data class GRpcError( +data class GRpcError( val code: Int, val message: String, ) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt similarity index 71% rename from generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt rename to common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt index 6202a619..8fc641be 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. */ -package com.google.ai.client.generativeai.internal.api.shared +package com.google.ai.client.generativeai.common.shared -import com.google.ai.client.generativeai.internal.util.FirstOrdinalSerializer +import com.google.ai.client.generativeai.common.util.FirstOrdinalSerializer import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.EncodeDefault import kotlinx.serialization.ExperimentalSerializationApi @@ -28,11 +28,11 @@ import kotlinx.serialization.json.JsonContentPolymorphicSerializer import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.jsonObject -internal object HarmCategorySerializer : +object HarmCategorySerializer : KSerializer by FirstOrdinalSerializer(HarmCategory::class) @Serializable(HarmCategorySerializer::class) -internal enum class HarmCategory { +enum class HarmCategory { UNKNOWN, @SerialName("HARM_CATEGORY_HARASSMENT") HARASSMENT, @SerialName("HARM_CATEGORY_HATE_SPEECH") HATE_SPEECH, @@ -44,25 +44,25 @@ typealias Base64 = String @ExperimentalSerializationApi @Serializable -internal data class Content(@EncodeDefault val role: String? = "user", val parts: List) +data class Content(@EncodeDefault val role: String? = "user", val parts: List) -@Serializable(PartSerializer::class) internal sealed interface Part +@Serializable(PartSerializer::class) sealed interface Part -@Serializable internal data class TextPart(val text: String) : Part +@Serializable data class TextPart(val text: String) : Part -@Serializable internal data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part +@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part @Serializable -internal data class Blob( +data class Blob( @SerialName("mime_type") val mimeType: String, val data: Base64, ) @Serializable -internal data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold) +data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold) @Serializable -internal enum class HarmBlockThreshold { +enum class HarmBlockThreshold { @SerialName("HARM_BLOCK_THRESHOLD_UNSPECIFIED") UNSPECIFIED, BLOCK_LOW_AND_ABOVE, BLOCK_MEDIUM_AND_ABOVE, @@ -70,7 +70,7 @@ internal enum class HarmBlockThreshold { BLOCK_NONE, } -internal object PartSerializer : JsonContentPolymorphicSerializer(Part::class) { +object PartSerializer : JsonContentPolymorphicSerializer(Part::class) { override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject return when { diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/kotlin.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/kotlin.kt new file mode 100644 index 00000000..8f681542 --- /dev/null +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/kotlin.kt @@ -0,0 +1,41 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common.util + +import java.lang.reflect.Field + +/** + * Removes the last character from the [StringBuilder]. + * + * If the StringBuilder is empty, calling this function will throw an [IndexOutOfBoundsException]. + * + * @return The [StringBuilder] used to make the call, for optional chaining. + * @throws IndexOutOfBoundsException if the StringBuilder is empty. + */ +internal fun StringBuilder.removeLast(): StringBuilder = + if (isEmpty()) throw IndexOutOfBoundsException("StringBuilder is empty.") + else deleteCharAt(length - 1) + +/** + * A variant of [getAnnotation][Field.getAnnotation] that provides implicit Kotlin support. + * + * Syntax sugar for: + * ``` + * getAnnotation(T::class.java) + * ``` + */ +internal inline fun Field.getAnnotation() = getAnnotation(T::class.java) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/ktor.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/ktor.kt similarity index 97% rename from generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/ktor.kt rename to common/src/main/kotlin/com/google/ai/client/generativeai/common/util/ktor.kt index ae3434e3..04eb18dc 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/ktor.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/ktor.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ @file:Suppress("DEPRECATION") // a replacement for our purposes has not been published yet -package com.google.ai.client.generativeai.internal.util +package com.google.ai.client.generativeai.common.util import io.ktor.utils.io.ByteChannel import io.ktor.utils.io.ByteReadChannel diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/serialization.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/serialization.kt similarity index 88% rename from generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/serialization.kt rename to common/src/main/kotlin/com/google/ai/client/generativeai/common/util/serialization.kt index 76405fe3..d818c266 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/serialization.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/util/serialization.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,10 @@ * limitations under the License. */ -package com.google.ai.client.generativeai.internal.util +package com.google.ai.client.generativeai.common.util import android.util.Log -import com.google.ai.client.generativeai.type.SerializationException +import com.google.ai.client.generativeai.common.SerializationException import kotlin.reflect.KClass import kotlinx.serialization.KSerializer import kotlinx.serialization.SerialName @@ -34,8 +34,7 @@ import kotlinx.serialization.encoding.Encoder * When an unknown enum value is found, the enum itself will be logged to stderr with a message * about opening an issue on GitHub regarding the new enum value. */ -internal class FirstOrdinalSerializer>(private val enumClass: KClass) : - KSerializer { +class FirstOrdinalSerializer>(private val enumClass: KClass) : KSerializer { override val descriptor: SerialDescriptor = buildClassSerialDescriptor("FirstOrdinalSerializer") override fun deserialize(decoder: Decoder): T { @@ -71,7 +70,7 @@ internal class FirstOrdinalSerializer>(private val enumClass: KClass * By default an enum is serialized to its [name][Enum.name], and can be overwritten by providing a * [SerialName] annotation. */ -internal val > T.serialName: String +val > T.serialName: String get() = declaringJavaClass.getField(name).getAnnotation()?.value ?: name /** @@ -80,5 +79,5 @@ internal val > T.serialName: String * @throws SerializationException if the class is not a valid enum. Beyond runtime emily magic, this * shouldn't really be possible. */ -internal fun > KClass.enumValues(): Array = +fun > KClass.enumValues(): Array = java.enumConstants ?: throw SerializationException("$simpleName is not a valid enum type.") diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/GenerativeModelTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/GenerativeModelTests.kt new file mode 100644 index 00000000..6269b673 --- /dev/null +++ b/common/src/test/java/com/google/ai/client/generativeai/common/GenerativeModelTests.kt @@ -0,0 +1,112 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.TextPart +import com.google.ai.client.generativeai.common.util.commonTest +import com.google.ai.client.generativeai.common.util.createResponses +import com.google.ai.client.generativeai.common.util.doBlocking +import com.google.ai.client.generativeai.common.util.prepareStreamingResponse +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.headersOf +import io.ktor.utils.io.ByteChannel +import io.ktor.utils.io.close +import io.ktor.utils.io.writeFully +import kotlin.time.Duration.Companion.seconds +import kotlinx.coroutines.withTimeout +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +internal class GenerativeModelTests { + private val testTimeout = 5.seconds + + @Test + fun `(generateContentStream) emits responses as they come in`() = commonTest { + val response = createResponses("The", " world", " is", " a", " beautiful", " place!") + val bytes = prepareStreamingResponse(response) + + bytes.forEach { channel.writeFully(it) } + val responses = apiController.generateContentStream(textGenerateContentRequest("test")) + + withTimeout(testTimeout) { + responses.collect { + it.candidates?.isEmpty() shouldBe false + channel.close() + } + } + } + + @Test + fun `(generateContent) respects a custom timeout`() = + commonTest(requestOptions = RequestOptions(2.seconds)) { + shouldThrow { + withTimeout(testTimeout) { + apiController.generateContent(textGenerateContentRequest("test")) + } + } + } +} + +@RunWith(Parameterized::class) +internal class ModelNamingTests(private val modelName: String, private val actualName: String) { + + @Test + fun `request should include right model name`() = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + val controller = APIController("super_cool_test_key", modelName, RequestOptions(), mockEngine) + + withTimeout(5.seconds) { + controller.generateContentStream(textGenerateContentRequest("cats")).collect { + it.candidates?.isEmpty() shouldBe false + channel.close() + } + } + + mockEngine.requestHistory.first().url.encodedPath shouldContain actualName + } + + companion object { + @JvmStatic + @Parameterized.Parameters + fun data() = + listOf( + arrayOf("gemini-pro", "models/gemini-pro"), + arrayOf("x/gemini-pro", "x/gemini-pro"), + arrayOf("models/gemini-pro", "models/gemini-pro"), + arrayOf("/modelname", "/modelname"), + arrayOf("modifiedNaming/mymodel", "modifiedNaming/mymodel"), + ) + } +} + +fun textGenerateContentRequest(prompt: String) = + GenerateContentRequest( + model = "unused", + contents = listOf(Content(parts = listOf(TextPart(prompt)))) + ) diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/StreamingSnapshotTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt similarity index 56% rename from generativeai/src/test/java/com/google/ai/client/generativeai/StreamingSnapshotTests.kt rename to common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt index 4615bcdc..7f151320 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/StreamingSnapshotTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/StreamingSnapshotTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,19 +14,17 @@ * limitations under the License. */ -package com.google.ai.client.generativeai - -import com.google.ai.client.generativeai.type.BlockReason -import com.google.ai.client.generativeai.type.FinishReason -import com.google.ai.client.generativeai.type.HarmCategory -import com.google.ai.client.generativeai.type.InvalidAPIKeyException -import com.google.ai.client.generativeai.type.PromptBlockedException -import com.google.ai.client.generativeai.type.ResponseStoppedException -import com.google.ai.client.generativeai.type.SerializationException -import com.google.ai.client.generativeai.type.ServerException -import com.google.ai.client.generativeai.util.goldenStreamingFile +package com.google.ai.client.generativeai.common + +import com.google.ai.client.generativeai.common.server.BlockReason +import com.google.ai.client.generativeai.common.server.FinishReason +import com.google.ai.client.generativeai.common.shared.HarmCategory +import com.google.ai.client.generativeai.common.shared.TextPart +import com.google.ai.client.generativeai.common.util.goldenStreamingFile import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain import io.ktor.http.HttpStatusCode import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.flow.collect @@ -41,29 +39,29 @@ internal class StreamingSnapshotTests { @Test fun `short reply`() = goldenStreamingFile("success-basic-reply-short.txt") { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { val responseList = responses.toList() responseList.isEmpty() shouldBe false - responseList.first().candidates.first().finishReason shouldBe FinishReason.STOP - responseList.first().candidates.first().content.parts.isEmpty() shouldBe false - responseList.first().candidates.first().safetyRatings.isEmpty() shouldBe false + responseList.first().candidates?.first()?.finishReason shouldBe FinishReason.STOP + responseList.first().candidates?.first()?.content?.parts?.isEmpty() shouldBe false + responseList.first().candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } } @Test fun `long reply`() = goldenStreamingFile("success-basic-reply-long.txt") { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { val responseList = responses.toList() responseList.isEmpty() shouldBe false responseList.forEach { - it.candidates.first().finishReason shouldBe FinishReason.STOP - it.candidates.first().content.parts.isEmpty() shouldBe false - it.candidates.first().safetyRatings.isEmpty() shouldBe false + it.candidates?.first()?.finishReason shouldBe FinishReason.STOP + it.candidates?.first()?.content?.parts?.isEmpty() shouldBe false + it.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } } } @@ -71,11 +69,13 @@ internal class StreamingSnapshotTests { @Test fun `unknown enum`() = goldenStreamingFile("success-unknown-enum.txt") { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { responses.first { - it.candidates.any { it.safetyRatings.any { it.category == HarmCategory.UNKNOWN } } + it.candidates?.any { + it.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } ?: false + } ?: false } } } @@ -83,20 +83,22 @@ internal class StreamingSnapshotTests { @Test fun `quotes escaped`() = goldenStreamingFile("success-quotes-escaped.txt") { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { val responseList = responses.toList() responseList.isEmpty() shouldBe false - responseList.first().text!!.contains("\"") + val part = responseList.first().candidates?.first()?.content?.parts?.first() as? TextPart + part.shouldNotBeNull() + part.text shouldContain "\"" } } @Test fun `prompt blocked for safety`() = goldenStreamingFile("failure-prompt-blocked-safety.txt") { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } @@ -107,7 +109,7 @@ internal class StreamingSnapshotTests { @Test fun `empty content`() = goldenStreamingFile("failure-empty-content.txt") { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { shouldThrow { responses.collect() } } } @@ -115,7 +117,7 @@ internal class StreamingSnapshotTests { @Test fun `http errors`() = goldenStreamingFile("failure-http-error.txt", HttpStatusCode.PreconditionFailed) { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { shouldThrow { responses.collect() } } } @@ -123,51 +125,57 @@ internal class StreamingSnapshotTests { @Test fun `stopped for safety`() = goldenStreamingFile("failure-finish-reason-safety.txt") { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY + exception.response.candidates?.first()?.finishReason shouldBe FinishReason.SAFETY } } @Test fun `citation parsed correctly`() = goldenStreamingFile("success-citations.txt") { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { val responseList = responses.toList() - responseList.any { it.candidates.any { it.citationMetadata.isNotEmpty() } } shouldBe true + responseList.any { + it.candidates?.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } + ?: false + } shouldBe true } } @Test fun `citation returns correctly when using alternative name`() = goldenStreamingFile("success-citations-altname.txt") { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { val responseList = responses.toList() - responseList.any { it.candidates.any { it.citationMetadata.isNotEmpty() } } shouldBe true + responseList.any { + it.candidates?.any { it.citationMetadata?.citationSources?.isNotEmpty() ?: false } + ?: false + } shouldBe true } } @Test fun `stopped for recitation`() = goldenStreamingFile("failure-recitation-no-content.txt") { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates.first().finishReason shouldBe FinishReason.RECITATION + exception.response.candidates?.first()?.finishReason shouldBe FinishReason.RECITATION } } @Test fun `image rejected`() = goldenStreamingFile("failure-image-rejected.txt", HttpStatusCode.BadRequest) { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { shouldThrow { responses.collect() } } } @@ -175,7 +183,7 @@ internal class StreamingSnapshotTests { @Test fun `unknown model`() = goldenStreamingFile("failure-unknown-model.txt", HttpStatusCode.NotFound) { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { shouldThrow { responses.collect() } } } @@ -183,7 +191,7 @@ internal class StreamingSnapshotTests { @Test fun `invalid api key`() = goldenStreamingFile("failure-api-key.txt", HttpStatusCode.BadRequest) { - val responses = model.generateContentStream() + val responses = apiController.generateContentStream(textGenerateContentRequest("prompt")) withTimeout(testTimeout) { shouldThrow { responses.collect() } } } diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt new file mode 100644 index 00000000..a05284da --- /dev/null +++ b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt @@ -0,0 +1,195 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.common + +import com.google.ai.client.generativeai.common.server.BlockReason +import com.google.ai.client.generativeai.common.server.FinishReason +import com.google.ai.client.generativeai.common.shared.HarmCategory +import com.google.ai.client.generativeai.common.util.goldenUnaryFile +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.should +import io.kotest.matchers.shouldBe +import io.ktor.http.HttpStatusCode +import kotlin.time.Duration.Companion.seconds +import kotlinx.coroutines.withTimeout +import org.junit.Test + +internal class UnarySnapshotTests { + private val testTimeout = 5.seconds + + @Test + fun `short reply`() = + goldenUnaryFile("success-basic-reply-short.json") { + withTimeout(testTimeout) { + val response = apiController.generateContent(textGenerateContentRequest("prompt")) + + response.candidates?.isEmpty() shouldBe false + response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false + response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false + } + } + + @Test + fun `long reply`() = + goldenUnaryFile("success-basic-reply-long.json") { + withTimeout(testTimeout) { + val response = apiController.generateContent(textGenerateContentRequest("prompt")) + + response.candidates?.isEmpty() shouldBe false + response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false + response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false + } + } + + @Test + fun `unknown enum`() = + goldenUnaryFile("success-unknown-enum.json") { + withTimeout(testTimeout) { + val response = apiController.generateContent(textGenerateContentRequest("prompt")) + + response.candidates?.first { + it.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } ?: false + } + } + } + + @Test + fun `prompt blocked for safety`() = + goldenUnaryFile("failure-prompt-blocked-safety.json") { + withTimeout(testTimeout) { + shouldThrow { + apiController.generateContent(textGenerateContentRequest("prompt")) + } should { it.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY } + } + } + + @Test + fun `empty content`() = + goldenUnaryFile("failure-empty-content.json") { + withTimeout(testTimeout) { + shouldThrow { + apiController.generateContent(textGenerateContentRequest("prompt")) + } + } + } + + @Test + fun `http error`() = + goldenUnaryFile("failure-http-error.json", HttpStatusCode.PreconditionFailed) { + withTimeout(testTimeout) { + shouldThrow { + apiController.generateContent(textGenerateContentRequest("prompt")) + } + } + } + + @Test + fun `user location error`() = + goldenUnaryFile("failure-unsupported-user-location.json", HttpStatusCode.PreconditionFailed) { + withTimeout(testTimeout) { + shouldThrow { + apiController.generateContent(textGenerateContentRequest("prompt")) + } + } + } + + @Test + fun `stopped for safety`() = + goldenUnaryFile("failure-finish-reason-safety.json") { + withTimeout(testTimeout) { + val exception = + shouldThrow { + apiController.generateContent(textGenerateContentRequest("prompt")) + } + exception.response.candidates?.first()?.finishReason shouldBe FinishReason.SAFETY + } + } + + @Test + fun `citation returns correctly`() = + goldenUnaryFile("success-citations.json") { + withTimeout(testTimeout) { + val response = apiController.generateContent(textGenerateContentRequest("prompt")) + + response.candidates?.isEmpty() shouldBe false + response.candidates?.first()?.citationMetadata?.citationSources?.isNotEmpty() shouldBe true + } + } + + @Test + fun `citation returns correctly when using alternative name`() = + goldenUnaryFile("success-citations-altname.json") { + withTimeout(testTimeout) { + val response = apiController.generateContent(textGenerateContentRequest("prompt")) + + response.candidates?.isEmpty() shouldBe false + response.candidates?.first()?.citationMetadata?.citationSources?.isNotEmpty() shouldBe true + } + } + + @Test + fun `invalid response`() = + goldenUnaryFile("failure-invalid-response.json") { + withTimeout(testTimeout) { + shouldThrow { + apiController.generateContent(textGenerateContentRequest("prompt")) + } + } + } + + @Test + fun `malformed content`() = + goldenUnaryFile("failure-malformed-content.json") { + withTimeout(testTimeout) { + shouldThrow { + apiController.generateContent(textGenerateContentRequest("prompt")) + } + } + } + + @Test + fun `invalid api key`() = + goldenUnaryFile("failure-api-key.json", HttpStatusCode.BadRequest) { + withTimeout(testTimeout) { + shouldThrow { + apiController.generateContent(textGenerateContentRequest("prompt")) + } + } + } + + @Test + fun `image rejected`() = + goldenUnaryFile("failure-image-rejected.json", HttpStatusCode.BadRequest) { + withTimeout(testTimeout) { + shouldThrow { + apiController.generateContent(textGenerateContentRequest("prompt")) + } + } + } + + @Test + fun `unknown model`() = + goldenUnaryFile("failure-unknown-model.json", HttpStatusCode.NotFound) { + withTimeout(testTimeout) { + shouldThrow { + apiController.generateContent(textGenerateContentRequest("prompt")) + } + } + } +} diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/util/kotlin.kt b/common/src/test/java/com/google/ai/client/generativeai/common/util/kotlin.kt similarity index 93% rename from generativeai/src/test/java/com/google/ai/client/generativeai/util/kotlin.kt rename to common/src/test/java/com/google/ai/client/generativeai/common/util/kotlin.kt index 259f72eb..c7ebf102 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/util/kotlin.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/util/kotlin.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.ai.client.generativeai.util +package com.google.ai.client.generativeai.common.util import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.runBlocking diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt similarity index 78% rename from generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt rename to common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt index 90707336..d753a9bf 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/util/tests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,19 +16,16 @@ @file:Suppress("DEPRECATION") // a replacement for our purposes has not been published yet -package com.google.ai.client.generativeai.util - -import com.google.ai.client.generativeai.GenerativeModel -import com.google.ai.client.generativeai.internal.api.APIController -import com.google.ai.client.generativeai.internal.api.GenerateContentRequest -import com.google.ai.client.generativeai.internal.api.GenerateContentResponse -import com.google.ai.client.generativeai.internal.api.JSON -import com.google.ai.client.generativeai.internal.api.server.Candidate -import com.google.ai.client.generativeai.internal.api.shared.Content -import com.google.ai.client.generativeai.internal.api.shared.TextPart -import com.google.ai.client.generativeai.internal.util.SSE_SEPARATOR -import com.google.ai.client.generativeai.internal.util.send -import com.google.ai.client.generativeai.type.RequestOptions +package com.google.ai.client.generativeai.common.util + +import com.google.ai.client.generativeai.common.APIController +import com.google.ai.client.generativeai.common.GenerateContentRequest +import com.google.ai.client.generativeai.common.GenerateContentResponse +import com.google.ai.client.generativeai.common.JSON +import com.google.ai.client.generativeai.common.RequestOptions +import com.google.ai.client.generativeai.common.server.Candidate +import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.TextPart import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.mock.respond import io.ktor.http.HttpHeaders @@ -65,11 +62,11 @@ internal fun createResponses(vararg text: String): List * Wrapper around common instances needed in tests. * * @param channel A [ByteChannel] for sending responses through the mock HTTP engine - * @param model A [GenerativeModel] that consumes the [channel] + * @param apiController A [APIController] that consumes the [channel] * @see commonTest * @see send */ -internal data class CommonTestScope(val channel: ByteChannel, val model: GenerativeModel) +internal data class CommonTestScope(val channel: ByteChannel, val apiController: APIController) /** A test that runs under a [CommonTestScope]. */ internal typealias CommonTest = suspend CommonTestScope.() -> Unit @@ -107,23 +104,10 @@ internal fun commonTest( val mockEngine = MockEngine { respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) } - val model = createGenerativeModel("gemini-pro", "super_cool_test_key", requestOptions, mockEngine) - CommonTestScope(channel, model).block() + val apiController = APIController("super_cool_test_key", "gemini-pro", requestOptions, mockEngine) + CommonTestScope(channel, apiController).block() } -/** Simple wrapper that guarantees the model and APIController are created using the same data */ -internal fun createGenerativeModel( - name: String, - apikey: String, - requestOptions: RequestOptions = RequestOptions(), - engine: MockEngine -) = - GenerativeModel( - name, - apikey, - controller = APIController("super_cool_test_key", name, requestOptions, engine) - ) - /** * A variant of [commonTest] for performing *streaming-based* snapshot tests. * diff --git a/generativeai/src/test/resources/golden-files/streaming/failure-api-key.txt b/common/src/test/resources/golden-files/streaming/failure-api-key.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/failure-api-key.txt rename to common/src/test/resources/golden-files/streaming/failure-api-key.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/failure-empty-content.txt b/common/src/test/resources/golden-files/streaming/failure-empty-content.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/failure-empty-content.txt rename to common/src/test/resources/golden-files/streaming/failure-empty-content.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/failure-finish-reason-safety.txt b/common/src/test/resources/golden-files/streaming/failure-finish-reason-safety.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/failure-finish-reason-safety.txt rename to common/src/test/resources/golden-files/streaming/failure-finish-reason-safety.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/failure-http-error.txt b/common/src/test/resources/golden-files/streaming/failure-http-error.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/failure-http-error.txt rename to common/src/test/resources/golden-files/streaming/failure-http-error.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/failure-image-rejected.txt b/common/src/test/resources/golden-files/streaming/failure-image-rejected.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/failure-image-rejected.txt rename to common/src/test/resources/golden-files/streaming/failure-image-rejected.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/failure-prompt-blocked-safety.txt b/common/src/test/resources/golden-files/streaming/failure-prompt-blocked-safety.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/failure-prompt-blocked-safety.txt rename to common/src/test/resources/golden-files/streaming/failure-prompt-blocked-safety.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/failure-recitation-no-content.txt b/common/src/test/resources/golden-files/streaming/failure-recitation-no-content.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/failure-recitation-no-content.txt rename to common/src/test/resources/golden-files/streaming/failure-recitation-no-content.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/failure-unknown-model.txt b/common/src/test/resources/golden-files/streaming/failure-unknown-model.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/failure-unknown-model.txt rename to common/src/test/resources/golden-files/streaming/failure-unknown-model.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/success-basic-reply-long.txt b/common/src/test/resources/golden-files/streaming/success-basic-reply-long.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/success-basic-reply-long.txt rename to common/src/test/resources/golden-files/streaming/success-basic-reply-long.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/success-basic-reply-short.txt b/common/src/test/resources/golden-files/streaming/success-basic-reply-short.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/success-basic-reply-short.txt rename to common/src/test/resources/golden-files/streaming/success-basic-reply-short.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/success-citations-altname.txt b/common/src/test/resources/golden-files/streaming/success-citations-altname.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/success-citations-altname.txt rename to common/src/test/resources/golden-files/streaming/success-citations-altname.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/success-citations.txt b/common/src/test/resources/golden-files/streaming/success-citations.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/success-citations.txt rename to common/src/test/resources/golden-files/streaming/success-citations.txt diff --git a/generativeai/src/test/resources/golden-files/streaming/success-quotes-escaped.txt b/common/src/test/resources/golden-files/streaming/success-quotes-escaped.txt similarity index 64% rename from generativeai/src/test/resources/golden-files/streaming/success-quotes-escaped.txt rename to common/src/test/resources/golden-files/streaming/success-quotes-escaped.txt index 0c48e4c4..ef71be29 100644 --- a/generativeai/src/test/resources/golden-files/streaming/success-quotes-escaped.txt +++ b/common/src/test/resources/golden-files/streaming/success-quotes-escaped.txt @@ -1,4 +1,4 @@ -data: {"candidates": [{"content": {"parts": [{"text": " Pineapples and bananas are two different types of fruit. Pineapples grow on a"}]},"index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_TOXICITY","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUAL","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_VIOLENCE","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DEROGATORY","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS","probability": "NEGLIGIBLE"}]}],"promptFeedback": {"safetyRatings": [{"category": "HARM_CATEGORY_TOXICITY","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUAL","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_VIOLENCE","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DEROGATORY","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS","probability": "NEGLIGIBLE"}]}} +data: {"candidates": [{"content": {"parts": [{"text": " Pineapples and \"bananas\" are two different types of fruit. Pineapples grow on a"}]},"index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_TOXICITY","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUAL","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_VIOLENCE","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DEROGATORY","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS","probability": "NEGLIGIBLE"}]}],"promptFeedback": {"safetyRatings": [{"category": "HARM_CATEGORY_TOXICITY","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUAL","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_VIOLENCE","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DEROGATORY","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS","probability": "NEGLIGIBLE"}]}} data: {"candidates": [{"content": {"parts": [{"text": " tropical plant with a rosette of long, pointed leaves. Bananas grow on a herbaceous"}]},"index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_TOXICITY","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUAL","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_VIOLENCE","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DEROGATORY","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS","probability": "NEGLIGIBLE"}]}]} diff --git a/generativeai/src/test/resources/golden-files/streaming/success-unknown-enum.txt b/common/src/test/resources/golden-files/streaming/success-unknown-enum.txt similarity index 100% rename from generativeai/src/test/resources/golden-files/streaming/success-unknown-enum.txt rename to common/src/test/resources/golden-files/streaming/success-unknown-enum.txt diff --git a/generativeai/src/test/resources/golden-files/unary/failure-api-key.json b/common/src/test/resources/golden-files/unary/failure-api-key.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/failure-api-key.json rename to common/src/test/resources/golden-files/unary/failure-api-key.json diff --git a/generativeai/src/test/resources/golden-files/unary/failure-empty-content.json b/common/src/test/resources/golden-files/unary/failure-empty-content.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/failure-empty-content.json rename to common/src/test/resources/golden-files/unary/failure-empty-content.json diff --git a/generativeai/src/test/resources/golden-files/unary/failure-finish-reason-safety.json b/common/src/test/resources/golden-files/unary/failure-finish-reason-safety.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/failure-finish-reason-safety.json rename to common/src/test/resources/golden-files/unary/failure-finish-reason-safety.json diff --git a/generativeai/src/test/resources/golden-files/unary/failure-http-error.json b/common/src/test/resources/golden-files/unary/failure-http-error.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/failure-http-error.json rename to common/src/test/resources/golden-files/unary/failure-http-error.json diff --git a/generativeai/src/test/resources/golden-files/unary/failure-image-rejected.json b/common/src/test/resources/golden-files/unary/failure-image-rejected.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/failure-image-rejected.json rename to common/src/test/resources/golden-files/unary/failure-image-rejected.json diff --git a/generativeai/src/test/resources/golden-files/unary/failure-invalid-response.json b/common/src/test/resources/golden-files/unary/failure-invalid-response.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/failure-invalid-response.json rename to common/src/test/resources/golden-files/unary/failure-invalid-response.json diff --git a/generativeai/src/test/resources/golden-files/unary/failure-malformed-content.json b/common/src/test/resources/golden-files/unary/failure-malformed-content.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/failure-malformed-content.json rename to common/src/test/resources/golden-files/unary/failure-malformed-content.json diff --git a/generativeai/src/test/resources/golden-files/unary/failure-prompt-blocked-safety.json b/common/src/test/resources/golden-files/unary/failure-prompt-blocked-safety.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/failure-prompt-blocked-safety.json rename to common/src/test/resources/golden-files/unary/failure-prompt-blocked-safety.json diff --git a/generativeai/src/test/resources/golden-files/unary/failure-unknown-model.json b/common/src/test/resources/golden-files/unary/failure-unknown-model.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/failure-unknown-model.json rename to common/src/test/resources/golden-files/unary/failure-unknown-model.json diff --git a/generativeai/src/test/resources/golden-files/unary/failure-unsupported-user-location.json b/common/src/test/resources/golden-files/unary/failure-unsupported-user-location.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/failure-unsupported-user-location.json rename to common/src/test/resources/golden-files/unary/failure-unsupported-user-location.json diff --git a/generativeai/src/test/resources/golden-files/unary/success-basic-reply-long.json b/common/src/test/resources/golden-files/unary/success-basic-reply-long.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/success-basic-reply-long.json rename to common/src/test/resources/golden-files/unary/success-basic-reply-long.json diff --git a/generativeai/src/test/resources/golden-files/unary/success-basic-reply-short.json b/common/src/test/resources/golden-files/unary/success-basic-reply-short.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/success-basic-reply-short.json rename to common/src/test/resources/golden-files/unary/success-basic-reply-short.json diff --git a/generativeai/src/test/resources/golden-files/unary/success-citations-altname.json b/common/src/test/resources/golden-files/unary/success-citations-altname.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/success-citations-altname.json rename to common/src/test/resources/golden-files/unary/success-citations-altname.json diff --git a/generativeai/src/test/resources/golden-files/unary/success-citations.json b/common/src/test/resources/golden-files/unary/success-citations.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/success-citations.json rename to common/src/test/resources/golden-files/unary/success-citations.json diff --git a/generativeai/src/test/resources/golden-files/unary/success-quote-reply.json b/common/src/test/resources/golden-files/unary/success-quote-reply.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/success-quote-reply.json rename to common/src/test/resources/golden-files/unary/success-quote-reply.json diff --git a/generativeai/src/test/resources/golden-files/unary/success-unknown-enum.json b/common/src/test/resources/golden-files/unary/success-unknown-enum.json similarity index 100% rename from generativeai/src/test/resources/golden-files/unary/success-unknown-enum.json rename to common/src/test/resources/golden-files/unary/success-unknown-enum.json diff --git a/generativeai/build.gradle.kts b/generativeai/build.gradle.kts index 39d5f72c..bac2aefa 100644 --- a/generativeai/build.gradle.kts +++ b/generativeai/build.gradle.kts @@ -73,15 +73,8 @@ android { } dependencies { - val ktorVersion = "2.3.2" + implementation(project(":common")) - implementation("io.ktor:ktor-client-okhttp:$ktorVersion") - implementation("io.ktor:ktor-client-core:$ktorVersion") - implementation("io.ktor:ktor-client-content-negotiation:$ktorVersion") - implementation("io.ktor:ktor-serialization-kotlinx-json:$ktorVersion") - implementation("io.ktor:ktor-client-logging:$ktorVersion") - - implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1") implementation("androidx.core:core-ktx:1.12.0") implementation("org.slf4j:slf4j-nop:2.0.9") implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3") @@ -92,10 +85,9 @@ dependencies { implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha02") implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha02") testImplementation("junit:junit:4.13.2") - testImplementation("io.kotest:kotest-assertions-core:4.0.7") - testImplementation("io.kotest:kotest-assertions-jvm:4.0.7") - testImplementation("io.kotest:kotest-assertions-json:4.0.7") - testImplementation("io.ktor:ktor-client-mock:$ktorVersion") + testImplementation("io.kotest:kotest-assertions-core:5.5.5") + testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5") + testImplementation("io.mockk:mockk:1.13.10") androidTestImplementation("androidx.test.ext:junit:1.1.5") androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 77c86cce..690f5be4 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -17,9 +17,9 @@ package com.google.ai.client.generativeai import android.graphics.Bitmap -import com.google.ai.client.generativeai.internal.api.APIController -import com.google.ai.client.generativeai.internal.api.CountTokensRequest -import com.google.ai.client.generativeai.internal.api.GenerateContentRequest +import com.google.ai.client.generativeai.common.APIController +import com.google.ai.client.generativeai.common.CountTokensRequest +import com.google.ai.client.generativeai.common.GenerateContentRequest import com.google.ai.client.generativeai.internal.util.toInternal import com.google.ai.client.generativeai.internal.util.toPublic import com.google.ai.client.generativeai.type.Content @@ -71,7 +71,7 @@ internal constructor( generationConfig, safetySettings, requestOptions, - APIController(apiKey, modelName, requestOptions) + APIController(apiKey, modelName, requestOptions.toInternal()) ) /** @@ -97,8 +97,8 @@ internal constructor( fun generateContentStream(vararg prompt: Content): Flow = controller .generateContentStream(constructRequest(*prompt)) - .map { it.toPublic().validate() } .catch { throw GoogleGenerativeAIException.from(it) } + .map { it.toPublic().validate() } /** * Generates a response from the backend with the provided text represented [Content]. diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index df0d5285..e4b236b7 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -19,24 +19,25 @@ package com.google.ai.client.generativeai.internal.util import android.graphics.Bitmap import android.graphics.BitmapFactory import android.util.Base64 -import com.google.ai.client.generativeai.internal.api.CountTokensResponse -import com.google.ai.client.generativeai.internal.api.GenerateContentResponse -import com.google.ai.client.generativeai.internal.api.client.GenerationConfig -import com.google.ai.client.generativeai.internal.api.server.BlockReason -import com.google.ai.client.generativeai.internal.api.server.Candidate -import com.google.ai.client.generativeai.internal.api.server.CitationSources -import com.google.ai.client.generativeai.internal.api.server.FinishReason -import com.google.ai.client.generativeai.internal.api.server.HarmProbability -import com.google.ai.client.generativeai.internal.api.server.PromptFeedback -import com.google.ai.client.generativeai.internal.api.server.SafetyRating -import com.google.ai.client.generativeai.internal.api.shared.Blob -import com.google.ai.client.generativeai.internal.api.shared.BlobPart -import com.google.ai.client.generativeai.internal.api.shared.Content -import com.google.ai.client.generativeai.internal.api.shared.HarmBlockThreshold -import com.google.ai.client.generativeai.internal.api.shared.HarmCategory -import com.google.ai.client.generativeai.internal.api.shared.Part -import com.google.ai.client.generativeai.internal.api.shared.SafetySetting -import com.google.ai.client.generativeai.internal.api.shared.TextPart +import com.google.ai.client.generativeai.common.CountTokensResponse +import com.google.ai.client.generativeai.common.GenerateContentResponse +import com.google.ai.client.generativeai.common.RequestOptions +import com.google.ai.client.generativeai.common.client.GenerationConfig +import com.google.ai.client.generativeai.common.server.BlockReason +import com.google.ai.client.generativeai.common.server.Candidate +import com.google.ai.client.generativeai.common.server.CitationSources +import com.google.ai.client.generativeai.common.server.FinishReason +import com.google.ai.client.generativeai.common.server.HarmProbability +import com.google.ai.client.generativeai.common.server.PromptFeedback +import com.google.ai.client.generativeai.common.server.SafetyRating +import com.google.ai.client.generativeai.common.shared.Blob +import com.google.ai.client.generativeai.common.shared.BlobPart +import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.HarmBlockThreshold +import com.google.ai.client.generativeai.common.shared.HarmCategory +import com.google.ai.client.generativeai.common.shared.Part +import com.google.ai.client.generativeai.common.shared.SafetySetting +import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata import com.google.ai.client.generativeai.type.ImagePart @@ -46,6 +47,9 @@ import java.io.ByteArrayOutputStream private const val BASE_64_FLAGS = Base64.NO_WRAP +internal fun com.google.ai.client.generativeai.type.RequestOptions.toInternal() = + RequestOptions(timeout, apiVersion) + internal fun com.google.ai.client.generativeai.type.Content.toInternal() = Content(this.role, this.parts.map { it.toInternal() }) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Exceptions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Exceptions.kt index 5b451f01..5cfc197e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Exceptions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Exceptions.kt @@ -17,12 +17,13 @@ package com.google.ai.client.generativeai.type import com.google.ai.client.generativeai.GenerativeModel -import io.ktor.serialization.JsonConvertException +import com.google.ai.client.generativeai.internal.util.toPublic import kotlinx.coroutines.TimeoutCancellationException /** Parent class for any errors that occur from [GenerativeModel]. */ sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = null) : RuntimeException(message, cause) { + companion object { /** @@ -34,12 +35,30 @@ sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = nu fun from(cause: Throwable): GoogleGenerativeAIException = when (cause) { is GoogleGenerativeAIException -> cause - is JsonConvertException, - is kotlinx.serialization.SerializationException -> - SerializationException( - "Something went wrong while trying to deserialize a response from the server.", - cause - ) + is com.google.ai.client.generativeai.common.GoogleGenerativeAIException -> + when (cause) { + is com.google.ai.client.generativeai.common.SerializationException -> + SerializationException(cause.message ?: "", cause.cause) + is com.google.ai.client.generativeai.common.ServerException -> + ServerException(cause.message ?: "", cause.cause) + is com.google.ai.client.generativeai.common.InvalidAPIKeyException -> + InvalidAPIKeyException( + cause.message ?: "", + ) + is com.google.ai.client.generativeai.common.PromptBlockedException -> + PromptBlockedException(cause.response.toPublic(), cause.cause) + is com.google.ai.client.generativeai.common.UnsupportedUserLocationException -> + UnsupportedUserLocationException(cause.cause) + is com.google.ai.client.generativeai.common.InvalidStateException -> + InvalidStateException(cause.message ?: "", cause) + is com.google.ai.client.generativeai.common.ResponseStoppedException -> + ResponseStoppedException(cause.response.toPublic(), cause.cause) + is com.google.ai.client.generativeai.common.RequestTimeoutException -> + RequestTimeoutException(cause.message ?: "", cause.cause) + is com.google.ai.client.generativeai.common.UnknownException -> + UnknownException(cause.message ?: "", cause.cause) + else -> UnknownException(cause.message ?: "", cause) + } is TimeoutCancellationException -> RequestTimeoutException("The request failed to complete in the allotted time.") else -> UnknownException("Something unexpected happened.", cause) @@ -66,6 +85,7 @@ class InvalidAPIKeyException(message: String, cause: Throwable? = null) : * * @property response the full server response for the request. */ +// TODO(rlazo): Add secondary constructor to pass through the message? class PromptBlockedException(val response: GenerateContentResponse, cause: Throwable? = null) : GoogleGenerativeAIException( "Prompt was blocked: ${response.promptFeedback?.blockReason?.name}", @@ -79,6 +99,7 @@ class PromptBlockedException(val response: GenerateContentResponse, cause: Throw * [list of regions](https://ai.google.dev/available_regions#available_regions) (countries and * territories) where the API is available. */ +// TODO(rlazo): Add secondary constructor to pass through the message? class UnsupportedUserLocationException(cause: Throwable? = null) : GoogleGenerativeAIException("User location is not supported for the API use.", cause) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt index cc9669d9..e25561fb 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt @@ -16,7 +16,6 @@ package com.google.ai.client.generativeai.type -import io.ktor.client.plugins.HttpTimeout import kotlin.time.Duration import kotlin.time.DurationUnit import kotlin.time.toDuration @@ -31,10 +30,7 @@ import kotlin.time.toDuration class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") { @JvmOverloads constructor( - timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS, + timeout: Long? = Long.MAX_VALUE, apiVersion: String = "v1" - ) : this( - (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS), - apiVersion - ) + ) : this((timeout ?: Long.MAX_VALUE).toDuration(DurationUnit.MILLISECONDS), apiVersion) } diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index 15145389..bfc0b24c 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -16,91 +16,119 @@ package com.google.ai.client.generativeai -import com.google.ai.client.generativeai.type.RequestOptions -import com.google.ai.client.generativeai.type.RequestTimeoutException -import com.google.ai.client.generativeai.util.commonTest -import com.google.ai.client.generativeai.util.createGenerativeModel -import com.google.ai.client.generativeai.util.createResponses -import com.google.ai.client.generativeai.util.doBlocking -import com.google.ai.client.generativeai.util.prepareStreamingResponse +import com.google.ai.client.generativeai.common.APIController +import com.google.ai.client.generativeai.common.GenerateContentRequest as GenerateContentRequest_Common +import com.google.ai.client.generativeai.common.GenerateContentResponse as GenerateContentResponse_Common +import com.google.ai.client.generativeai.common.InvalidAPIKeyException as InvalidAPIKeyException_Common +import com.google.ai.client.generativeai.common.UnsupportedUserLocationException as UnsupportedUserLocationException_Common +import com.google.ai.client.generativeai.common.server.Candidate as Candidate_Common +import com.google.ai.client.generativeai.common.shared.Content as Content_Common +import com.google.ai.client.generativeai.common.shared.TextPart as TextPart_Common +import com.google.ai.client.generativeai.type.Candidate +import com.google.ai.client.generativeai.type.Content +import com.google.ai.client.generativeai.type.GenerateContentResponse +import com.google.ai.client.generativeai.type.InvalidAPIKeyException +import com.google.ai.client.generativeai.type.PromptFeedback +import com.google.ai.client.generativeai.type.TextPart +import com.google.ai.client.generativeai.type.UnsupportedUserLocationException import io.kotest.assertions.throwables.shouldThrow -import io.kotest.matchers.shouldBe -import io.kotest.matchers.string.shouldContain -import io.ktor.client.engine.mock.MockEngine -import io.ktor.client.engine.mock.respond -import io.ktor.http.HttpHeaders -import io.ktor.http.HttpStatusCode -import io.ktor.http.headersOf -import io.ktor.utils.io.ByteChannel -import io.ktor.utils.io.close -import io.ktor.utils.io.writeFully -import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.withTimeout +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.equality.shouldBeEqualToUsingFields +import io.mockk.coEvery +import io.mockk.mockk +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.runBlocking import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.Parameterized internal class GenerativeModelTests { - private val testTimeout = 5.seconds + + private val apiKey: String = "api_key" + private val mockApiController = mockk() @Test - fun `(generateContentStream) emits responses as they come in`() = commonTest { - val response = createResponses("The", " world", " is", " a", " beautiful", " place!") - val bytes = prepareStreamingResponse(response) + fun `generateContent request succeeds`() = doBlocking { + val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController) + coEvery { + mockApiController.generateContent( + GenerateContentRequest_Common( + "gemini-pro-1.0", + contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))) + ) + ) + } returns + GenerateContentResponse_Common( + listOf( + Candidate_Common( + content = + Content_Common( + parts = listOf(TextPart_Common("I'm still learning how to answer this question")) + ), + finishReason = null, + safetyRatings = listOf(), + citationMetadata = null + ) + ) + ) + + val expectedResponse = + GenerateContentResponse( + listOf( + Candidate( + Content(parts = listOf(TextPart("I'm still learning how to answer this question"))), + safetyRatings = listOf(), + citationMetadata = listOf(), + finishReason = null + ) + ), + PromptFeedback(null, listOf()) + ) - bytes.forEach { channel.writeFully(it) } - val responses = model.generateContentStream() + val response = model.generateContent("Why's the sky blue?") - withTimeout(testTimeout) { - responses.collect { - it.candidates.isEmpty() shouldBe false - channel.close() - } - } + response.shouldBeEqualToUsingFields(expectedResponse, GenerateContentResponse::text) + response.candidates shouldHaveSize expectedResponse.candidates.size + response.candidates[0].shouldBeEqualToUsingFields( + expectedResponse.candidates[0], + Candidate::finishReason, + Candidate::citationMetadata, + Candidate::safetyRatings + ) } @Test - fun `(generateContent) respects a custom timeout`() = - commonTest(requestOptions = RequestOptions(2.seconds)) { - shouldThrow { - withTimeout(testTimeout) { model.generateContent("d") } - } - } -} + fun `generateContent throws exception`() = doBlocking { + val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController) + coEvery { + mockApiController.generateContent( + GenerateContentRequest_Common( + "gemini-pro-1.0", + contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))) + ) + ) + } throws InvalidAPIKeyException_Common("exception message") -@RunWith(Parameterized::class) -internal class ModelNamingTests(private val modelName: String, private val actualName: String) { + shouldThrow { model.generateContent("Why's the sky blue?") } + } @Test - fun `request should include right model name`() = doBlocking { - val channel = ByteChannel(autoFlush = true) - val mockEngine = MockEngine { - respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) - } - prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } - val model = - createGenerativeModel(modelName, "super_cool_test_key", RequestOptions(), mockEngine) + fun `generateContentStream throws exception`() = doBlocking { + val model = GenerativeModel("gemini-pro-1.0", apiKey, controller = mockApiController) + coEvery { + mockApiController.generateContentStream( + GenerateContentRequest_Common( + "gemini-pro-1.0", + contents = listOf(Content_Common(parts = listOf(TextPart_Common("Why's the sky blue?")))) + ) + ) + } returns flow { throw UnsupportedUserLocationException_Common() } - withTimeout(5.seconds) { - model.generateContentStream().collect { - it.candidates.isEmpty() shouldBe false - channel.close() - } + shouldThrow { + model.generateContentStream("Why's the sky blue?").collect {} } - - mockEngine.requestHistory.first().url.encodedPath shouldContain actualName } +} - companion object { - @JvmStatic - @Parameterized.Parameters - fun data() = - listOf( - arrayOf("gemini-pro", "models/gemini-pro"), - arrayOf("x/gemini-pro", "x/gemini-pro"), - arrayOf("models/gemini-pro", "models/gemini-pro"), - arrayOf("/modelname", "/modelname"), - arrayOf("modifiedNaming/mymodel", "modifiedNaming/mymodel"), - ) - } +internal fun doBlocking(block: suspend CoroutineScope.() -> Unit) { + runBlocking(block = block) } diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/UnarySnapshotTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/UnarySnapshotTests.kt deleted file mode 100644 index 3c15a9e5..00000000 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/UnarySnapshotTests.kt +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Copyright 2023 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.client.generativeai - -import com.google.ai.client.generativeai.type.BlockReason -import com.google.ai.client.generativeai.type.FinishReason -import com.google.ai.client.generativeai.type.HarmCategory -import com.google.ai.client.generativeai.type.InvalidAPIKeyException -import com.google.ai.client.generativeai.type.PromptBlockedException -import com.google.ai.client.generativeai.type.ResponseStoppedException -import com.google.ai.client.generativeai.type.SerializationException -import com.google.ai.client.generativeai.type.ServerException -import com.google.ai.client.generativeai.type.UnsupportedUserLocationException -import com.google.ai.client.generativeai.util.goldenUnaryFile -import io.kotest.assertions.throwables.shouldThrow -import io.kotest.matchers.should -import io.kotest.matchers.shouldBe -import io.ktor.http.HttpStatusCode -import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.withTimeout -import org.junit.Test - -internal class UnarySnapshotTests { - private val testTimeout = 5.seconds - - @Test - fun `short reply`() = - goldenUnaryFile("success-basic-reply-short.json") { - withTimeout(testTimeout) { - val response = model.generateContent() - - response.candidates.isEmpty() shouldBe false - response.candidates.first().finishReason shouldBe FinishReason.STOP - response.candidates.first().content.parts.isEmpty() shouldBe false - response.candidates.first().safetyRatings.isEmpty() shouldBe false - } - } - - @Test - fun `long reply`() = - goldenUnaryFile("success-basic-reply-long.json") { - withTimeout(testTimeout) { - val response = model.generateContent() - - response.candidates.isEmpty() shouldBe false - response.candidates.first().finishReason shouldBe FinishReason.STOP - response.candidates.first().content.parts.isEmpty() shouldBe false - response.candidates.first().safetyRatings.isEmpty() shouldBe false - } - } - - @Test - fun `unknown enum`() = - goldenUnaryFile("success-unknown-enum.json") { - withTimeout(testTimeout) { - val response = model.generateContent() - - response.candidates.first { it.safetyRatings.any { it.category == HarmCategory.UNKNOWN } } - } - } - - @Test - fun `prompt blocked for safety`() = - goldenUnaryFile("failure-prompt-blocked-safety.json") { - withTimeout(testTimeout) { - shouldThrow { model.generateContent() } should - { - it.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY - } - } - } - - @Test - fun `empty content`() = - goldenUnaryFile("failure-empty-content.json") { - withTimeout(testTimeout) { shouldThrow { model.generateContent() } } - } - - @Test - fun `http error`() = - goldenUnaryFile("failure-http-error.json", HttpStatusCode.PreconditionFailed) { - withTimeout(testTimeout) { shouldThrow { model.generateContent() } } - } - - @Test - fun `user location error`() = - goldenUnaryFile("failure-unsupported-user-location.json", HttpStatusCode.PreconditionFailed) { - withTimeout(testTimeout) { - shouldThrow { model.generateContent() } - } - } - - @Test - fun `stopped for safety`() = - goldenUnaryFile("failure-finish-reason-safety.json") { - withTimeout(testTimeout) { - val exception = shouldThrow { model.generateContent() } - exception.response.candidates.first().finishReason shouldBe FinishReason.SAFETY - } - } - - @Test - fun `citation returns correctly`() = - goldenUnaryFile("success-citations.json") { - withTimeout(testTimeout) { - val response = model.generateContent() - - response.candidates.isEmpty() shouldBe false - response.candidates.first().citationMetadata.isNotEmpty() shouldBe true - } - } - - @Test - fun `citation returns correctly when using alternative name`() = - goldenUnaryFile("success-citations-altname.json") { - withTimeout(testTimeout) { - val response = model.generateContent() - - response.candidates.isEmpty() shouldBe false - response.candidates.first().citationMetadata.isNotEmpty() shouldBe true - } - } - - @Test - fun `invalid response`() = - goldenUnaryFile("failure-invalid-response.json") { - withTimeout(testTimeout) { shouldThrow { model.generateContent() } } - } - - @Test - fun `malformed content`() = - goldenUnaryFile("failure-malformed-content.json") { - withTimeout(testTimeout) { shouldThrow { model.generateContent() } } - } - - @Test - fun `invalid api key`() = - goldenUnaryFile("failure-api-key.json", HttpStatusCode.BadRequest) { - withTimeout(testTimeout) { shouldThrow { model.generateContent() } } - } - - @Test - fun `image rejected`() = - goldenUnaryFile("failure-image-rejected.json", HttpStatusCode.BadRequest) { - withTimeout(testTimeout) { shouldThrow { model.generateContent() } } - } - - @Test - fun `unknown model`() = - goldenUnaryFile("failure-unknown-model.json", HttpStatusCode.NotFound) { - withTimeout(testTimeout) { shouldThrow { model.generateContent() } } - } -} diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/internal/util/ConversionsTest.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/internal/util/ConversionsTest.kt deleted file mode 100644 index 31255fbe..00000000 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/internal/util/ConversionsTest.kt +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.ai.client.generativeai.internal.util - -import com.google.ai.client.generativeai.internal.api.shared.Content -import com.google.ai.client.generativeai.internal.api.shared.TextPart -import com.google.ai.client.generativeai.type.content -import io.kotest.matchers.shouldBe -import org.junit.Test - -class ConversionsTest { - - @Test - fun `test content conversion toInternal (role not mentioned)`() { - val content = content { text("test") }.toInternal() - content.run { - // default role should be a "user" - role shouldBe "user" - - // only one part should be present - parts.size shouldBe 1 - parts[0].run { (this as TextPart).text shouldBe "test" } - } - } - - @Test - fun `test content conversion toInternal (role mentioned)`() { - val content = content(role = "model") { text("test") }.toInternal() - content.run { - // Role should be a "model" - role shouldBe "model" - - // only one part should be present - parts.size shouldBe 1 - parts[0].run { (this as TextPart).text shouldBe "test" } - } - } - - @Test - fun `test content conversion toPublic (role not mentioned)`() { - val content = Content(parts = listOf(TextPart("test"))).toPublic() - content.role shouldBe "user" - } - - @Test - fun `test content conversion toPublic (role mentioned)`() { - val content = Content(role = "model", parts = listOf(TextPart("test"))).toPublic() - content.role shouldBe "model" - } -} diff --git a/settings.gradle.kts b/settings.gradle.kts index 44f363aa..3f4e53af 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -31,4 +31,5 @@ dependencyResolutionManagement { rootProject.name = "generativeai" include(":generativeai") +include(":common") includeBuild("./plugins")