Skip to content

Commit

Permalink
Merge branch 'main' into davidmotson.count_token_full
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmotson authored Jun 4, 2024
2 parents 51c1ef9 + b1803c4 commit 54d52cc
Show file tree
Hide file tree
Showing 33 changed files with 2,147 additions and 108 deletions.
1,414 changes: 1,414 additions & 0 deletions api/common/0.6.0.api

Large diffs are not rendered by default.

509 changes: 509 additions & 0 deletions api/generativeai/0.7.0.api

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ plugins {
id("org.jetbrains.dokka") version "1.8.20" apply false
kotlin("android") version "1.8.22" apply false
kotlin("plugin.serialization") version "1.8.22" apply false
id("com.ncorti.ktfmt.gradle") version "0.16.0" apply false
id("com.ncorti.ktfmt.gradle") version "0.18.0" apply false
id("license-plugin")
id("multi-project-plugin")
}
Expand Down
2 changes: 1 addition & 1 deletion common/gradle.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.5.0
version=0.6.0
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ internal constructor(
private val requestOptions: RequestOptions,
httpEngine: HttpClientEngine,
private val apiClient: String,
private val headerProvider: HeaderProvider?
private val headerProvider: HeaderProvider?,
) {

constructor(
key: String,
model: String,
requestOptions: RequestOptions,
apiClient: String,
headerProvider: HeaderProvider? = null
headerProvider: HeaderProvider? = null,
) : this(key, model, requestOptions, OkHttp.create(), apiClient, headerProvider)

private val model = fullModelName(model)
Expand Down Expand Up @@ -223,12 +223,13 @@ private fun fullModelName(name: String): String = name.takeIf { it.contains("/")
private suspend fun validateResponse(response: HttpResponse) {
if (response.status == HttpStatusCode.OK) return
val text = response.bodyAsText()
val message =
val error =
try {
JSON.decodeFromString<GRpcErrorResponse>(text).error.message
JSON.decodeFromString<GRpcErrorResponse>(text).error
} catch (e: Throwable) {
"Unexpected Response:\n$text"
throw ServerException("Unexpected Response:\n$text $e")
}
val message = error.message
if (message.contains("API key not valid")) {
throw InvalidAPIKeyException(message)
}
Expand All @@ -239,6 +240,9 @@ private suspend fun validateResponse(response: HttpResponse) {
if (message.contains("quota")) {
throw QuotaExceededException(message)
}
if (error.details?.any { "SERVICE_DISABLED" == it.reason } == true) {
throw ServiceDisabledException(message)
}
throw ServerException(message)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ class RequestTimeoutException(message: String, cause: Throwable? = null) :
class QuotaExceededException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)

/** The service is not enabled for this project. Visit the Firebase Console to enable it. */
class ServiceDisabledException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)

/** Catch all case for exceptions not explicitly expected. */
class UnknownException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ sealed interface Response
data class GenerateContentResponse(
val candidates: List<Candidate>? = null,
val promptFeedback: PromptFeedback? = null,
val usageMetadata: UsageMetadata? = null
val usageMetadata: UsageMetadata? = null,
) : Response

@Serializable
Expand All @@ -40,5 +40,5 @@ data class CountTokensResponse(val totalTokens: Int, val totalBillableCharacters
data class UsageMetadata(
val promptTokenCount: Int? = null,
val candidatesTokenCount: Int? = null,
val totalTokenCount: Int? = null
val totalTokenCount: Int? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ data class GenerationConfig(
@SerialName("response_mime_type") val responseMimeType: String? = null,
@SerialName("presence_penalty") val presencePenalty: Float? = null,
@SerialName("frequency_penalty") val frequencyPenalty: Float? = null,
@SerialName("response_schema") val responseSchema: Schema? = null,
)

@Serializable data class Tool(val functionDeclarations: List<FunctionDeclaration>)
Expand All @@ -51,17 +52,14 @@ data class FunctionCallingConfig(val mode: Mode) {
}

@Serializable
data class FunctionDeclaration(
val name: String,
val description: String,
val parameters: Schema,
)
data class FunctionDeclaration(val name: String, val description: String, val parameters: Schema)

@Serializable
data class Schema(
val type: String,
val description: String? = null,
val format: String? = null,
val nullable: Boolean? = false,
val enum: List<String>? = null,
val properties: Map<String, Schema>? = null,
val required: List<String>? = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ data class CitationSources(
val startIndex: Int = 0,
val endIndex: Int,
val uri: String,
val license: String? = null
val license: String? = null,
)

@Serializable
Expand Down Expand Up @@ -138,7 +138,6 @@ enum class FinishReason {
}

@Serializable
data class GRpcError(
val code: Int,
val message: String,
)
data class GRpcError(val code: Int, val message: String, val details: List<GRpcErrorDetails>)

@Serializable data class GRpcErrorDetails(val reason: String? = null)
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,17 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List<Pa

@Serializable data class FunctionResponse(val name: String, val response: JsonObject)

@Serializable data class FunctionCall(val name: String, val args: Map<String, String>)
@Serializable data class FunctionCall(val name: String, val args: Map<String, String?>)

@Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part

@Serializable
data class FileData(
@SerialName("mime_type") val mimeType: String,
@SerialName("file_uri") val fileUri: String
@SerialName("file_uri") val fileUri: String,
)

@Serializable
data class Blob(
@SerialName("mime_type") val mimeType: String,
val data: Base64,
)
@Serializable data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64)

@Serializable
data class SafetySetting(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class FirstOrdinalSerializer<T : Enum<T>>(private val enumClass: KClass<T>) : KS
|GitHub to bring it to our attention:
|https://github.com/google/google-ai-android
"""
.trimMargin()
.trimMargin(),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
"genai-android/${BuildConfig.VERSION_NAME}",
null
null,
)

withTimeout(5.seconds) {
Expand All @@ -121,7 +121,7 @@ internal class RequestFormatTests {
RequestOptions(endpoint = "https://my.custom.endpoint"),
mockEngine,
TEST_CLIENT_ID,
null
null,
)

withTimeout(5.seconds) {
Expand All @@ -148,7 +148,7 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null
null,
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
Expand All @@ -171,7 +171,7 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null
null,
)

withTimeout(5.seconds) {
Expand All @@ -184,7 +184,7 @@ internal class RequestFormatTests {
ToolConfig(
functionCallingConfig =
FunctionCallingConfig(mode = FunctionCallingConfig.Mode.AUTO)
)
),
)
)
.collect { channel.close() }
Expand Down Expand Up @@ -218,7 +218,7 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
testHeaderProvider
testHeaderProvider,
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
Expand Down Expand Up @@ -252,7 +252,7 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
testHeaderProvider
testHeaderProvider,
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
Expand All @@ -278,7 +278,7 @@ internal class ModelNamingTests(private val modelName: String, private val actua
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null
null,
)

withTimeout(5.seconds) {
Expand Down Expand Up @@ -308,7 +308,7 @@ internal class ModelNamingTests(private val modelName: String, private val actua
fun textGenerateContentRequest(prompt: String) =
GenerateContentRequest(
model = "unused",
contents = listOf(Content(parts = listOf(TextPart(prompt))))
contents = listOf(Content(parts = listOf(TextPart(prompt)))),
)

fun textCountTokenRequest(prompt: String) = CountTokensRequest(textGenerateContentRequest(prompt))
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.google.ai.client.generativeai.common.server.BlockReason
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.server.HarmProbability
import com.google.ai.client.generativeai.common.server.HarmSeverity
import com.google.ai.client.generativeai.common.shared.FunctionCallPart
import com.google.ai.client.generativeai.common.shared.HarmCategory
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.common.util.goldenUnaryFile
Expand Down Expand Up @@ -291,4 +292,25 @@ internal class UnarySnapshotTests {
}
}
}

@Test
fun `service disabled`() =
goldenUnaryFile("failure-service-disabled.json", HttpStatusCode.Forbidden) {
withTimeout(testTimeout) {
shouldThrow<ServiceDisabledException> {
apiController.generateContent(textGenerateContentRequest("prompt"))
}
}
}

@Test
fun `function call contains null param`() =
goldenUnaryFile("success-function-call-null.json") {
withTimeout(testTimeout) {
val response = apiController.generateContent(textGenerateContentRequest("prompt"))
val callPart = (response.candidates!!.first().content!!.parts.first() as FunctionCallPart)

callPart.functionCall.args["season"] shouldBe null
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ internal typealias CommonTest = suspend CommonTestScope.() -> Unit
internal fun commonTest(
status: HttpStatusCode = HttpStatusCode.OK,
requestOptions: RequestOptions = RequestOptions(),
block: CommonTest
block: CommonTest,
) = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
Expand All @@ -113,7 +113,7 @@ internal fun commonTest(
requestOptions,
mockEngine,
TEST_CLIENT_ID,
null
null,
)
CommonTestScope(channel, apiController).block()
}
Expand All @@ -132,7 +132,7 @@ internal fun commonTest(
internal fun goldenStreamingFile(
name: String,
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
block: CommonTest
block: CommonTest,
) = doBlocking {
val goldenFile = loadGoldenFile("streaming/$name")
val messages = goldenFile.readLines().filter { it.isNotBlank() }
Expand Down Expand Up @@ -162,7 +162,7 @@ internal fun goldenStreamingFile(
internal fun goldenUnaryFile(
name: String,
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
block: CommonTest
block: CommonTest,
) =
commonTest(httpStatusCode) {
val goldenFile = loadGoldenFile("unary/$name")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"error": {
"code": 403,
"message": "Firebase ML API has not been used in project 12345 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/firebaseml.googleapis.com/overview?project=12345 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry.",
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.Help",
"links": [
{
"description": "Google developers console API activation",
"url": "https://console.developers.google.com/apis/api/firebaseml.googleapis.com/overview?project=12345"
}
]
},
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "SERVICE_DISABLED",
"domain": "googleapis.com",
"metadata": {
"service": "firebaseml.googleapis.com",
"consumer": "projects/12345"
}
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
{
"candidates": [
{
"content": {
"parts": [
{
"functionCall": {
"name": "functionName",
"args": {
"original_title": "String",
"season": null
}
}
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0,
"safetyRatings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
}
]
}
],
"usageMetadata": {
"promptTokenCount": 774,
"candidatesTokenCount": 4176,
"totalTokenCount": 4950
}
}
2 changes: 1 addition & 1 deletion generativeai-android-sample/app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,5 @@ dependencies {
debugImplementation("androidx.compose.ui:ui-tooling")
debugImplementation("androidx.compose.ui:ui-test-manifest")

implementation("com.google.ai.client.generativeai:generativeai:0.6.0")
implementation("com.google.ai.client.generativeai:generativeai:0.7.0")
}
Loading

0 comments on commit 54d52cc

Please sign in to comment.