Skip to content

Commit

Permalink
Common & genai sdk split (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
rlazo authored Mar 20, 2024
1 parent bd5f554 commit 39396f0
Show file tree
Hide file tree
Showing 61 changed files with 1,041 additions and 525 deletions.
1 change: 1 addition & 0 deletions common/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/build
126 changes: 126 additions & 0 deletions common/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

plugins {
id("com.android.library")
id("maven-publish")
id("com.ncorti.ktfmt.gradle")
id("changelog-plugin")
id("release-plugin")
kotlin("android")
kotlin("plugin.serialization")
}

ktfmt {
googleStyle()
}

android {
namespace = "com.google.ai.client.generativeai.common"
compileSdk = 34

buildFeatures.buildConfig = true

defaultConfig {
minSdk = 21

testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")

buildConfigField("String", "VERSION_NAME", "\"${project.version.toString()}\"")
}

publishing {
singleVariant("release") {
withSourcesJar()
}
}

buildTypes {
release {
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_17
targetCompatibility = JavaVersion.VERSION_17
}
kotlinOptions {
jvmTarget = "17"
}

testOptions {
unitTests.isReturnDefaultValues = true
}
}

dependencies {
val ktorVersion = "2.3.2"

implementation("io.ktor:ktor-client-okhttp:$ktorVersion")
implementation("io.ktor:ktor-client-core:$ktorVersion")
implementation("io.ktor:ktor-client-content-negotiation:$ktorVersion")
implementation("io.ktor:ktor-serialization-kotlinx-json:$ktorVersion")
implementation("io.ktor:ktor-client-logging:$ktorVersion")

implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
implementation("androidx.core:core-ktx:1.12.0")
implementation("org.slf4j:slf4j-nop:2.0.9")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3")
implementation("org.reactivestreams:reactive-streams:1.0.3")

implementation("com.google.guava:listenablefuture:1.0")
implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha02")
implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha02")
testImplementation("junit:junit:4.13.2")
testImplementation("io.kotest:kotest-assertions-core:4.0.7")
testImplementation("io.kotest:kotest-assertions-jvm:4.0.7")
testImplementation("io.kotest:kotest-assertions-json:4.0.7")
testImplementation("io.ktor:ktor-client-mock:$ktorVersion")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
}

publishing {
publications {
register<MavenPublication>("release") {
groupId = "com.google.ai.client.generativeai"
artifactId = "common"
version = project.version.toString()
pom {
licenses {
license {
name = "The Apache License, Version 2.0"
url = "http://www.apache.org/licenses/LICENSE-2.0.txt"
}
}
}
afterEvaluate {
from(components["release"])
}
}
}
repositories {
maven {
url = uri("${projectDir}/m2")
}
}
}
21 changes: 21 additions & 0 deletions common/consumer-rules.pro
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html

# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}

# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable

# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
1 change: 1 addition & 0 deletions common/gradle.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
version=0.1.0
21 changes: 21 additions & 0 deletions common/proguard-rules.pro
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html

# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}

# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable

# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
18 changes: 18 additions & 0 deletions common/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
<uses-permission android:name="android.permission.INTERNET"/>
</manifest>
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 Google LLC
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,14 +14,10 @@
* limitations under the License.
*/

package com.google.ai.client.generativeai.internal.api
package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.BuildConfig
import com.google.ai.client.generativeai.internal.util.decodeToFlow
import com.google.ai.client.generativeai.type.InvalidAPIKeyException
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.ServerException
import com.google.ai.client.generativeai.type.UnsupportedUserLocationException
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.util.decodeToFlow
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
Expand All @@ -42,14 +38,16 @@ import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.timeout
import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json

internal const val DOMAIN = "https://generativelanguage.googleapis.com"
const val DOMAIN = "https://generativelanguage.googleapis.com"

internal val JSON = Json {
val JSON = Json {
ignoreUnknownKeys = true
prettyPrint = false
}
Expand All @@ -63,15 +61,21 @@ internal val JSON = Json {
* Exposed primarily for DI in tests.
* @property key The API key used for authentication.
* @property model The model to use for generation.
* @property apiVersion the endpoint version to communicate with.
* @property timeout the maximum amount of time for a request to take in the initial exchange.
*/
internal class APIController(
class APIController
internal constructor(
private val key: String,
model: String,
private val requestOptions: RequestOptions,
httpEngine: HttpClientEngine = OkHttp.create(),
httpEngine: HttpClientEngine
) {

constructor(
key: String,
model: String,
requestOptions: RequestOptions
) : this(key, model, requestOptions, OkHttp.create())

private val model = fullModelName(model)

private val client =
Expand All @@ -84,28 +88,39 @@ internal class APIController(
}

suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()

fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> {
return client.postStream<GenerateContentResponse>(
"$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
try {
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body<GenerateContentResponse>()
.validate()
} catch (e: Throwable) {
throw GoogleGenerativeAIException.from(e)
}
}

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> =
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") {
.postStream<GenerateContentResponse>(
"$DOMAIN/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()
.map { it.validate() }
.catch { throw GoogleGenerativeAIException.from(it) }

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
try {
client
.post("$DOMAIN/${requestOptions.apiVersion}/$model:countTokens") {
applyCommonConfiguration(request)
}
.also { validateResponse(it) }
.body()
} catch (e: Throwable) {
throw GoogleGenerativeAIException.from(e)
}

private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) {
when (request) {
Expand Down Expand Up @@ -170,21 +185,31 @@ private inline fun <reified R : Response> HttpClient.postStream(
}

private suspend fun validateResponse(response: HttpResponse) {
if (response.status != HttpStatusCode.OK) {
val text = response.bodyAsText()
val message =
try {
JSON.decodeFromString<GRpcErrorResponse>(text).error.message
} catch (e: Throwable) {
"Unexpected Response:\n$text"
}
if (message.contains("API key not valid")) {
throw InvalidAPIKeyException(message)
}
// TODO (b/325117891): Use a better method than string matching.
if (message == "User location is not supported for the API use.") {
throw UnsupportedUserLocationException()
if (response.status == HttpStatusCode.OK) return
val text = response.bodyAsText()
val message =
try {
JSON.decodeFromString<GRpcErrorResponse>(text).error.message
} catch (e: Throwable) {
"Unexpected Response:\n$text"
}
throw ServerException(message)
if (message.contains("API key not valid")) {
throw InvalidAPIKeyException(message)
}
// TODO (b/325117891): Use a better method than string matching.
if (message == "User location is not supported for the API use.") {
throw UnsupportedUserLocationException()
}
throw ServerException(message)
}

private fun GenerateContentResponse.validate() = apply {
if ((candidates?.isEmpty() != false) && promptFeedback == null) {
throw SerializationException("Error deserializing response, found no valid fields")
}
promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
candidates
?.mapNotNull { it.finishReason }
?.firstOrNull { it != FinishReason.STOP }
?.let { throw ResponseStoppedException(this) }
}
Loading

0 comments on commit 39396f0

Please sign in to comment.