Skip to content

Commit

Permalink
Merge branch 'main' into rl.api.requestoptions
Browse files Browse the repository at this point in the history
  • Loading branch information
rlazo committed Feb 28, 2024
2 parents 69ab2ab + be02307 commit 70b41ca
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 8 deletions.
1 change: 1 addition & 0 deletions .changes/branch-badge-dock-advertisement.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Make \"user\" the default role in all requests"]}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: String): GenerateContentResponse {
val content = content("user") { text(prompt) }
val content = content { text(prompt) }
return sendMessage(content)
}

Expand All @@ -84,7 +84,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Bitmap): GenerateContentResponse {
val content = content("user") { image(prompt) }
val content = content { image(prompt) }
return sendMessage(content)
}

Expand Down Expand Up @@ -150,7 +150,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
fun sendMessageStream(prompt: String): Flow<GenerateContentResponse> {
val content = content("user") { text(prompt) }
val content = content { text(prompt) }
return sendMessageStream(content)
}

Expand All @@ -162,7 +162,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
fun sendMessageStream(prompt: Bitmap): Flow<GenerateContentResponse> {
val content = content("user") { image(prompt) }
val content = content { image(prompt) }
return sendMessageStream(content)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package com.google.ai.client.generativeai.internal.api
import com.google.ai.client.generativeai.BuildConfig
import com.google.ai.client.generativeai.internal.util.decodeToFlow
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.InvalidAPIKeyException
import com.google.ai.client.generativeai.type.ServerException
import com.google.ai.client.generativeai.type.UnsupportedUserLocationException
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
Expand Down Expand Up @@ -176,7 +178,13 @@ private suspend fun validateResponse(response: HttpResponse) {
} catch (e: Throwable) {
"Unexpected Response:\n$text"
}

if (message.contains("API key not valid")) {
throw InvalidAPIKeyException(message)
}
// TODO (b/325117891): Use a better method than string matching.
if (message == "User location is not supported for the API use.") {
throw UnsupportedUserLocationException()
}
throw ServerException(message)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ internal enum class HarmCategory {

typealias Base64 = String

@Serializable internal data class Content(val role: String? = null, val parts: List<Part>)
@Serializable internal data class Content(val role: String? = "user", val parts: List<Part>)

@Serializable(PartSerializer::class) internal sealed interface Part

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class SerializationException(message: String, cause: Throwable? = null) :
class ServerException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)

/** The server responded that the API Key is no valid. */
class InvalidAPIKeyException(message: String, cause: Throwable? = null) :
GoogleGenerativeAIException(message, cause)

/**
* A request was blocked for some reason.
*
Expand All @@ -68,6 +72,16 @@ class PromptBlockedException(val response: GenerateContentResponse, cause: Throw
cause
)

/**
* The user's location (region) is not supported by the API.
*
* See the Google documentation for a
* [list of regions](https://ai.google.dev/available_regions#available_regions) (countries and
* territories) where the API is available.
*/
class UnsupportedUserLocationException(cause: Throwable? = null) :
GoogleGenerativeAIException("User location is not supported for the API use.", cause)

/**
* Some form of state occurred that shouldn't have.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.google.ai.client.generativeai
import com.google.ai.client.generativeai.type.BlockReason
import com.google.ai.client.generativeai.type.FinishReason
import com.google.ai.client.generativeai.type.HarmCategory
import com.google.ai.client.generativeai.type.InvalidAPIKeyException
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SerializationException
Expand Down Expand Up @@ -173,6 +174,6 @@ internal class StreamingSnapshotTests {
goldenStreamingFile("failure-api-key.txt", HttpStatusCode.BadRequest) {
val responses = model.generateContentStream()

withTimeout(testTimeout) { shouldThrow<ServerException> { responses.collect() } }
withTimeout(testTimeout) { shouldThrow<InvalidAPIKeyException> { responses.collect() } }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ package com.google.ai.client.generativeai
import com.google.ai.client.generativeai.type.BlockReason
import com.google.ai.client.generativeai.type.FinishReason
import com.google.ai.client.generativeai.type.HarmCategory
import com.google.ai.client.generativeai.type.InvalidAPIKeyException
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.ServerException
import com.google.ai.client.generativeai.type.UnsupportedUserLocationException
import com.google.ai.client.generativeai.util.goldenUnaryFile
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.should
Expand Down Expand Up @@ -94,6 +96,14 @@ internal class UnarySnapshotTests {
withTimeout(testTimeout) { shouldThrow<ServerException> { model.generateContent() } }
}

@Test
fun `user location error`() =
goldenUnaryFile("failure-unsupported-user-location.json", HttpStatusCode.PreconditionFailed) {
withTimeout(testTimeout) {
shouldThrow<UnsupportedUserLocationException> { model.generateContent() }
}
}

@Test
fun `stopped for safety`() =
goldenUnaryFile("failure-finish-reason-safety.json") {
Expand Down Expand Up @@ -129,7 +139,7 @@ internal class UnarySnapshotTests {
@Test
fun `invalid api key`() =
goldenUnaryFile("failure-api-key.json", HttpStatusCode.BadRequest) {
withTimeout(testTimeout) { shouldThrow<ServerException> { model.generateContent() } }
withTimeout(testTimeout) { shouldThrow<InvalidAPIKeyException> { model.generateContent() } }
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"error": {
"code": 400,
"message": "User location is not supported for the API use.",
"status": "FAILED_PRECONDITION",
"details": [
{
"@type": "type.googleapis.com/google.rpc.DebugInfo",
"detail": "[ORIGINAL ERROR] generic::failed_precondition: User location is not supported for the API use. [google.rpc.error_details_ext] { message: \"User location is not supported for the API use.\" }"
}
]
}
}

0 comments on commit 70b41ca

Please sign in to comment.