Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add constrained decoding support #99

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.GenerationConfig
import com.google.ai.client.generativeai.common.client.Tool
import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import kotlinx.serialization.SerialName
Expand All @@ -30,6 +32,8 @@ data class GenerateContentRequest(
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null,
@SerialName("tool_config") var toolConfig: ToolConfig? = null,
) : Request

@Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,24 @@ import kotlin.time.toDuration
* @property timeout the maximum amount of time for a request to take, from the first request to
* first response.
* @property apiVersion the api endpoint to call.
* @property disableAutoFunction if true, function calls requested by the model will not be executed automatically
*/
class RequestOptions(
val timeout: Duration,
val apiVersion: String = "v1",
val endpoint: String = "https://generativelanguage.googleapis.com"
val disableAutoFunction: Boolean = false,
val endpoint: String = "https://generativelanguage.googleapis.com",
) {
@JvmOverloads
constructor(
timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
apiVersion: String = "v1",
endpoint: String = "https://generativelanguage.googleapis.com"
disableAutoFunction: Boolean = false,
endpoint: String = "https://generativelanguage.googleapis.com",
) : this(
(timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
apiVersion,
endpoint
disableAutoFunction,
endpoint,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,41 @@ data class GenerationConfig(
@SerialName("top_k") val topK: Int?,
@SerialName("candidate_count") val candidateCount: Int?,
@SerialName("max_output_tokens") val maxOutputTokens: Int?,
@SerialName("stop_sequences") val stopSequences: List<String>?
@SerialName("stop_sequences") val stopSequences: List<String>?,
)

/**
 * A set of [FunctionDeclaration]s made available to the model for a request.
 *
 * NOTE(review): unlike the neighboring snake_case fields, `functionDeclarations` has no
 * @SerialName, so it serializes as camelCase — confirm this matches the wire format.
 */
@Serializable data class Tool(val functionDeclarations: List<FunctionDeclaration>)

/** Request-level configuration applying to all tools supplied in the request. */
@Serializable
data class ToolConfig(
// Serialized as snake_case to match the REST API field name.
@SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig
)

/** Controls if/how the model may call the declared functions. */
@Serializable
data class FunctionCallingConfig(val mode: Mode) {
/**
 * Function-calling behavior requested of the model.
 *
 * NOTE(review): presumably AUTO lets the model decide, ANY forces a function call, and NONE
 * forbids one — verify against the backend API documentation.
 */
@Serializable
enum class Mode {
// Wire value differs from the Kotlin constant name, hence the explicit @SerialName.
@SerialName("MODE_UNSPECIFIED") UNSPECIFIED,
AUTO,
ANY,
NONE
}
}

/**
 * Declares a single function the model may call: its name, a natural-language description the
 * model uses to decide when to call it, and a schema describing its parameters.
 */
@Serializable
data class FunctionDeclaration(
val name: String,
val description: String,
val parameters: FunctionParameterProperties,
)

/**
 * Schema-style description of a function parameter (JSON-schema/OpenAPI-like).
 *
 * Recursive: [properties] describes the members of an object-typed parameter and [items]
 * describes the element type of an array-typed parameter.
 */
@Serializable
data class FunctionParameterProperties(
// Type name of the parameter (e.g. a schema type string); only `type` is mandatory.
val type: String,
val description: String? = null,
val format: String? = null,
// Closed set of allowed values, when the parameter is an enumeration.
val enum: List<String>? = null,
// Member-name -> schema for object types.
val properties: Map<String, FunctionParameterProperties>? = null,
// Names of the required members of an object type.
val required: List<String>? = null,
// Element schema for array types.
val items: FunctionParameterProperties? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.JsonContentPolymorphicSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonObject

object HarmCategorySerializer :
Expand All @@ -52,11 +53,15 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List<Pa

@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part

@Serializable
data class Blob(
@SerialName("mime_type") val mimeType: String,
val data: Base64,
)
@Serializable data class FunctionCallPart(val functionCall: FunctionCall) : Part

@Serializable data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part

@Serializable data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64)

@Serializable data class FunctionResponse(val name: String, val response: JsonObject)

@Serializable data class FunctionCall(val name: String, val args: Map<String, String>)

@Serializable
data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold)
Expand All @@ -76,6 +81,8 @@ object PartSerializer : JsonContentPolymorphicSerializer<Part>(Part::class) {
return when {
"text" in jsonObject -> TextPart.serializer()
"inlineData" in jsonObject -> BlobPart.serializer()
"functionCall" in jsonObject -> FunctionCallPart.serializer()
"functionResponse" in jsonObject -> FunctionResponsePart.serializer()
else -> throw SerializationException("Unknown Part type")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@

package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.FunctionCallingConfig
import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.common.util.commonTest
import com.google.ai.client.generativeai.common.util.createResponses
import com.google.ai.client.generativeai.common.util.doBlocking
import com.google.ai.client.generativeai.common.util.prepareStreamingResponse
import io.kotest.assertions.json.shouldContainJsonKey
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.content.TextContent
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
Expand Down Expand Up @@ -114,6 +118,38 @@ internal class EndpointTests {

mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint"
}

@Test
fun `ToolConfig serialization is correct`() = doBlocking {
// Byte channel backing the mocked HTTP response body; closed by the collector below to
// terminate the stream.
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
// Pre-load one arbitrary streamed response so generateContentStream has something to emit.
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }

val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)

// Send a request carrying a ToolConfig; only the serialized request body matters here.
withTimeout(5.seconds) {
controller
.generateContentStream(
GenerateContentRequest(
model = "unused",
contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))),
toolConfig =
ToolConfig(
functionCallingConfig =
FunctionCallingConfig(mode = FunctionCallingConfig.Mode.AUTO)
)
)
)
.collect { channel.close() }
}

val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text

// Asserts the snake_case @SerialName mapping was applied to the nested tool config.
requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode"
}
}

@RunWith(Parameterized::class)
Expand Down
1 change: 1 addition & 0 deletions generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ dependencies {
implementation("org.slf4j:slf4j-nop:2.0.9")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
implementation("org.reactivestreams:reactive-streams:1.0.3")

implementation("com.google.guava:listenablefuture:1.0")
Expand Down
133 changes: 97 additions & 36 deletions generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package com.google.ai.client.generativeai
import android.graphics.Bitmap
import com.google.ai.client.generativeai.type.BlobPart
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.FunctionResponsePart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.ImagePart
import com.google.ai.client.generativeai.type.InvalidStateException
Expand All @@ -27,8 +29,9 @@ import com.google.ai.client.generativeai.type.content
import java.util.LinkedList
import java.util.concurrent.Semaphore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.transform

/**
* Representation of a back and forth interaction with a model.
Expand All @@ -45,6 +48,10 @@ import kotlinx.coroutines.flow.onEach
class Chat(private val model: GenerativeModel, val history: MutableList<Content> = ArrayList()) {
private var lock = Semaphore(1)

companion object {
private val VALID_ROLES = listOf("user", "function")
}

/**
* Generates a response from the backend with the provided [Content], and any previous ones
* sent/returned from this chat.
Expand All @@ -53,13 +60,27 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()
suspend fun sendMessage(inputPrompt: Content): GenerateContentResponse {
inputPrompt.assertComesFromUser()
attemptLock()
var response: GenerateContentResponse
var prompt = inputPrompt
val tempHistory = LinkedList<Content>()
try {
val response = model.generateContent(*history.toTypedArray(), prompt)
history.add(prompt)
history.add(response.candidates.first().content)
while (true) {
response =
model.generateContent(*history.toTypedArray(), *tempHistory.toTypedArray(), prompt)
val responsePart = response.candidates.first().content.parts.first()

tempHistory.add(prompt)
tempHistory.add(response.candidates.first().content)
if (responsePart !is FunctionCallPart) break
if (model.requestOptions.disableAutoFunction) break

val output = model.executeFunction(responsePart)
prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output)))
}
history.addAll(tempHistory)
return response
} finally {
lock.release()
Expand Down Expand Up @@ -101,43 +122,19 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
attemptLock()

val flow = model.generateContentStream(*history.toTypedArray(), prompt)
val bitmaps = LinkedList<Bitmap>()
val blobs = LinkedList<BlobPart>()
val text = StringBuilder()

val tempHistory = mutableListOf<Content>()
tempHistory.add(prompt)
/**
* TODO: revisit when images and blobs are returned. This will cause issues with how things are
* structured in the response. eg; a text/image/text response will be (incorrectly)
* represented as image/text
*/
return flow
.onEach {
for (part in it.candidates.first().content.parts) {
when (part) {
is TextPart -> text.append(part.text)
is ImagePart -> bitmaps.add(part.image)
is BlobPart -> blobs.add(part)
}
}
}
.transform { response -> automaticFunctionExecutingTransform(this, tempHistory, response) }
.onCompletion {
lock.release()
if (it == null) {
val content =
content("model") {
for (bitmap in bitmaps) {
image(bitmap)
}
for (blob in blobs) {
blob(blob.mimeType, blob.blob)
}
if (text.isNotBlank()) {
text(text.toString())
}
}

history.add(prompt)
history.add(content)
history.addAll(tempHistory)
}
}
}
Expand Down Expand Up @@ -167,9 +164,73 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}

private fun Content.assertComesFromUser() {
if (role != "user") {
throw InvalidStateException("Chat prompts should come from the 'user' role.")
if (!VALID_ROLES.contains(role)) {
throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.")
}
}

/**
 * Re-emits streamed responses while transparently executing any function calls the model requests.
 *
 * Text, image, and blob parts are forwarded to [transformer] and recorded in [tempHistory] under
 * the "model" role. When a [FunctionCallPart] arrives (and auto-execution is not disabled via
 * `requestOptions.disableAutoFunction`), the function is executed, its output is appended as a
 * "function"-role message, and a fresh stream is started from the accumulated history; that
 * stream is processed recursively through this same transform.
 *
 * @param transformer collector of the outer flow returned to the caller; responses pass through
 *   unchanged.
 * @param tempHistory scratch history for this exchange; the caller commits it to [history] once
 *   the flow completes successfully.
 * @param response the next streamed chunk to process.
 */
private suspend fun automaticFunctionExecutingTransform(
  transformer: FlowCollector<GenerateContentResponse>,
  tempHistory: MutableList<Content>,
  response: GenerateContentResponse,
) {
  for (part in response.candidates.first().content.parts) {
    when (part) {
      is TextPart -> {
        transformer.emit(response)
        addTextToHistory(tempHistory, part)
      }
      is ImagePart -> {
        transformer.emit(response)
        tempHistory.add(Content("model", listOf(part)))
      }
      is BlobPart -> {
        transformer.emit(response)
        tempHistory.add(Content("model", listOf(part)))
      }
      is FunctionCallPart -> {
        if (model.requestOptions.disableAutoFunction) {
          tempHistory.add(response.candidates.first().content)
          continue
        }
        // BUG FIX: execute the part currently being iterated. The previous code re-searched
        // `parts.first { it is FunctionCallPart }`, which always resolves to the FIRST function
        // call — a response containing several function calls would execute the first one
        // repeatedly and never execute the others. `part` is smart-cast to FunctionCallPart here.
        val output = model.executeFunction(part)
        val functionResponse =
          Content("function", listOf(FunctionResponsePart(part.name, output)))
        tempHistory.add(response.candidates.first().content)
        tempHistory.add(functionResponse)
        model
          .generateContentStream(*history.toTypedArray(), *tempHistory.toTypedArray())
          .collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) }
      }
    }
  }
}

/**
 * Appends a streamed text chunk to [tempHistory], merging consecutive text chunks.
 *
 * If the most recent entry is a "model" message that already contains text, that entry is
 * replaced by a copy whose last [TextPart] has [textPart]'s text appended — so a streamed reply
 * accumulates into one history entry instead of one entry per chunk. Otherwise a fresh "model"
 * entry is started.
 *
 * BUG FIX: the previous implementation appended the new chunk to EVERY TextPart of the last
 * entry, duplicating the chunk whenever a model message contained more than one text part. The
 * chunk is now appended only to the last TextPart.
 */
private fun addTextToHistory(tempHistory: MutableList<Content>, textPart: TextPart) {
  val lastContent = tempHistory.lastOrNull()
  if (lastContent?.role == "model" && lastContent.parts.any { it is TextPart }) {
    tempHistory.removeLast()
    val lastTextIndex = lastContent.parts.indexOfLast { it is TextPart }
    val mergedParts =
      lastContent.parts.mapIndexed { index, part ->
        if (index == lastTextIndex && part is TextPart) {
          TextPart(part.text + textPart.text)
        } else {
          part
        }
      }
    tempHistory.add(Content("model", mergedParts))
    return
  }
  tempHistory.add(Content("model", listOf(textPart)))
}

private fun attemptLock() {
Expand Down
Loading