Skip to content

Commit

Permalink
split the auto function calling changes between common and generativeai
Browse files Browse the repository at this point in the history
  • Loading branch information
David Motsonashvili committed Mar 21, 2024
1 parent 4638406 commit 58e46ab
Show file tree
Hide file tree
Showing 16 changed files with 674 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.GenerationConfig
import com.google.ai.client.generativeai.common.client.Tool
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

Expand All @@ -30,6 +32,7 @@ data class GenerateContentRequest(
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null
) : Request

@Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,24 @@ import kotlin.time.toDuration
* @property timeout the maximum amount of time for a request to take, from the first request to
* first response.
* @property apiVersion the api endpoint to call.
* @property disableAutoFunction if true, function calls returned by the model are not executed
* automatically.
*/
class RequestOptions(
val timeout: Duration,
val apiVersion: String = "v1",
val endpoint: String = "https://generativelanguage.googleapis.com"
val disableAutoFunction: Boolean = false,
val endpoint: String = "https://generativelanguage.googleapis.com",
) {
@JvmOverloads
constructor(
timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
apiVersion: String = "v1",
endpoint: String = "https://generativelanguage.googleapis.com"
disableAutoFunction: Boolean = false,
endpoint: String = "https://generativelanguage.googleapis.com",
) : this(
(timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
apiVersion,
endpoint
disableAutoFunction,
endpoint,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,23 @@ data class GenerationConfig(
@SerialName("max_output_tokens") val maxOutputTokens: Int?,
@SerialName("stop_sequences") val stopSequences: List<String>?
)

/**
 * A group of function declarations; a list of these can be attached to a request through its
 * `tools` field.
 *
 * @property functionDeclarations the functions described by this tool.
 */
@Serializable
data class Tool(
  val functionDeclarations: List<FunctionDeclaration>,
)

/**
 * Serializable description of a single function.
 *
 * @property name the function's name.
 * @property description free-form text describing the function.
 * @property parameters description of the function's parameters.
 */
@Serializable
data class FunctionDeclaration(
  val name: String,
  val description: String,
  val parameters: FunctionParameterProperties,
)

/**
 * Recursive, serializable description of a function parameter.
 *
 * Only [type] is mandatory; every other field defaults to `null`. The structure nests through
 * [properties] and [items].
 *
 * NOTE(review): the field names look like an OpenAPI-style schema object — confirm against the
 * backend API reference.
 *
 * @property type name of the parameter's type.
 * @property description optional text describing the parameter.
 * @property format optional format qualifier for [type].
 * @property enum optional list of string values.
 * @property properties optional nested descriptions, keyed by name.
 * @property required optional list of names.
 * @property items optional nested description.
 */
@Serializable
data class FunctionParameterProperties(
  val type: String,
  val description: String? = null,
  val format: String? = null,
  val enum: List<String>? = null,
  val properties: Map<String, FunctionParameterProperties>? = null,
  val required: List<String>? = null,
  val items: FunctionParameterProperties? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.JsonContentPolymorphicSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonObject

object HarmCategorySerializer :
Expand All @@ -52,12 +53,21 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List<Pa

/**
 * A [Part] that carries a [Blob].
 *
 * @property inlineData the wrapped blob; serialized under the name `inline_data`.
 */
@Serializable
data class BlobPart(
  @SerialName("inline_data") val inlineData: Blob,
) : Part

/**
 * A [Part] that carries a [FunctionCall].
 *
 * @property functionCall the wrapped function call.
 */
@Serializable
data class FunctionCallPart(
  val functionCall: FunctionCall,
) : Part

/**
 * A [Part] that carries a [FunctionResponse].
 *
 * @property functionResponse the wrapped function response.
 */
@Serializable
data class FunctionResponsePart(
  val functionResponse: FunctionResponse,
) : Part

/**
 * Typed payload held by a [BlobPart].
 *
 * @property mimeType the payload's MIME type; serialized under the name `mime_type`.
 * @property data the payload itself, typed as `Base64`.
 */
@Serializable
data class Blob(
  @SerialName("mime_type")
  val mimeType: String,
  val data: Base64,
)

/**
 * A named JSON object, carried by a [FunctionResponsePart].
 *
 * @property name the function's name.
 * @property response the JSON object payload.
 */
@Serializable
data class FunctionResponse(
  val name: String,
  val response: JsonObject,
)

/**
 * A function name with its arguments, carried by a [FunctionCallPart].
 *
 * NOTE(review): [args] is typed as string-to-string; argument values that arrive as JSON
 * numbers, booleans or nested objects may not decode into this shape — confirm against the
 * backend payloads and the JSON configuration in use.
 *
 * @property name the function's name.
 * @property args the arguments, keyed by name.
 */
@Serializable
data class FunctionCall(
  val name: String,
  val args: Map<String, String>,
)

/**
 * Pairs a harm category with a blocking threshold.
 *
 * @property category the harm category.
 * @property threshold the block threshold paired with [category].
 */
@Serializable
data class SafetySetting(
  val category: HarmCategory,
  val threshold: HarmBlockThreshold,
)

Expand All @@ -76,6 +86,8 @@ object PartSerializer : JsonContentPolymorphicSerializer<Part>(Part::class) {
return when {
"text" in jsonObject -> TextPart.serializer()
"inlineData" in jsonObject -> BlobPart.serializer()
"functionCall" in jsonObject -> FunctionCallPart.serializer()
"functionResponse" in jsonObject -> FunctionResponsePart.serializer()
else -> throw SerializationException("Unknown Part type")
}
}
Expand Down
1 change: 1 addition & 0 deletions generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ dependencies {
implementation("org.slf4j:slf4j-nop:2.0.9")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3")
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
implementation("org.reactivestreams:reactive-streams:1.0.3")

implementation("com.google.guava:listenablefuture:1.0")
Expand Down
133 changes: 97 additions & 36 deletions generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package com.google.ai.client.generativeai
import android.graphics.Bitmap
import com.google.ai.client.generativeai.type.BlobPart
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.FunctionResponsePart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.ImagePart
import com.google.ai.client.generativeai.type.InvalidStateException
Expand All @@ -27,8 +29,9 @@ import com.google.ai.client.generativeai.type.content
import java.util.LinkedList
import java.util.concurrent.Semaphore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.transform

/**
* Representation of a back and forth interaction with a model.
Expand All @@ -44,6 +47,10 @@ import kotlinx.coroutines.flow.onEach
*/
class Chat(private val model: GenerativeModel, val history: MutableList<Content> = ArrayList()) {
private var lock = Semaphore(1)
companion object {
  /** Roles a prompt may carry when it is handed to this chat. */
  private val VALID_ROLES = listOf("user", "function")
}


/**
* Generates a response from the backend with the provided [Content], and any previous ones
Expand All @@ -53,13 +60,27 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the prompt is not coming from the 'user' or 'function' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()
suspend fun sendMessage(inputPrompt: Content): GenerateContentResponse {
inputPrompt.assertComesFromUser()
attemptLock()
var response: GenerateContentResponse
var prompt = inputPrompt
val tempHistory = LinkedList<Content>()
try {
val response = model.generateContent(*history.toTypedArray(), prompt)
history.add(prompt)
history.add(response.candidates.first().content)
while (true) {
response =
model.generateContent(*history.toTypedArray(), *tempHistory.toTypedArray(), prompt)
val responsePart = response.candidates.first().content.parts.first()

tempHistory.add(prompt)
tempHistory.add(response.candidates.first().content)
if (responsePart !is FunctionCallPart) break
if (model.requestOptions.disableAutoFunction) break

val output = model.executeFunction(responsePart)
prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output)))
}
history.addAll(tempHistory)
return response
} finally {
lock.release()
Expand Down Expand Up @@ -101,43 +122,19 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
attemptLock()

val flow = model.generateContentStream(*history.toTypedArray(), prompt)
val bitmaps = LinkedList<Bitmap>()
val blobs = LinkedList<BlobPart>()
val text = StringBuilder()

val tempHistory = mutableListOf<Content>()
tempHistory.add(prompt)
/**
* TODO: revisit when images and blobs are returned. This will cause issues with how things are
* structured in the response, e.g. a text/image/text response will be (incorrectly)
* represented as image/text
*/
return flow
.onEach {
for (part in it.candidates.first().content.parts) {
when (part) {
is TextPart -> text.append(part.text)
is ImagePart -> bitmaps.add(part.image)
is BlobPart -> blobs.add(part)
}
}
}
.transform { response -> automaticFunctionExecutingTransform(this, tempHistory, response) }
.onCompletion {
lock.release()
if (it == null) {
val content =
content("model") {
for (bitmap in bitmaps) {
image(bitmap)
}
for (blob in blobs) {
blob(blob.mimeType, blob.blob)
}
if (text.isNotBlank()) {
text(text.toString())
}
}

history.add(prompt)
history.add(content)
history.addAll(tempHistory)
}
}
}
Expand Down Expand Up @@ -167,9 +164,73 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}

private fun Content.assertComesFromUser() {
if (role != "user") {
throw InvalidStateException("Chat prompts should come from the 'user' role.")
if (!VALID_ROLES.contains(role)) {
throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.")
}
}

/**
 * Processes one streamed [response]: forwards it downstream, records its parts in
 * [tempHistory], and executes any function call the model requested.
 *
 * For each [FunctionCallPart], unless `disableAutoFunction` is set on the model's request
 * options, the function is executed, the call and its result are appended to [tempHistory], and
 * a follow-up stream is requested whose responses are fed back through this function.
 *
 * Compared with the previous revision:
 * - the function call being iterated is the one that is executed (previously the first call in
 *   the response was looked up again, so later calls were never run);
 * - a response holding a function call is still emitted when automatic execution is disabled,
 *   matching `sendMessage`, which returns such a response to the caller;
 * - a response is emitted at most once, regardless of how many parts it contains.
 *
 * @param transformer collector that receives the responses visible to the caller.
 * @param tempHistory contents accumulated for the in-flight exchange; mutated in place.
 * @param response the streamed response to process.
 */
private suspend fun automaticFunctionExecutingTransform(
  transformer: FlowCollector<GenerateContentResponse>,
  tempHistory: MutableList<Content>,
  response: GenerateContentResponse,
) {
  val content = response.candidates.first().content

  // Guards against handing the same response to the collector once per part.
  var emitted = false
  suspend fun emitOnce() {
    if (!emitted) {
      emitted = true
      transformer.emit(response)
    }
  }

  for (part in content.parts) {
    when (part) {
      is TextPart -> {
        emitOnce()
        addTextToHistory(tempHistory, part)
      }
      is ImagePart,
      is BlobPart -> {
        emitOnce()
        tempHistory.add(Content("model", listOf(part)))
      }
      is FunctionCallPart -> {
        if (model.requestOptions.disableAutoFunction) {
          // The caller handles the call manually, so it has to see the response.
          emitOnce()
          tempHistory.add(content)
          continue
        }
        // `part` is smart-cast here; execute this call, not the first one in the response.
        val output = model.executeFunction(part)
        val functionResponse =
          Content("function", listOf(FunctionResponsePart(part.name, output)))
        tempHistory.add(content)
        tempHistory.add(functionResponse)
        model
          .generateContentStream(*history.toTypedArray(), *tempHistory.toTypedArray())
          .collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) }
      }
    }
  }
}

/**
 * Records a streamed text chunk in [tempHistory].
 *
 * When the most recent entry has the "model" role and already contains text, the chunk is
 * appended to that entry's last [TextPart] and the entry is replaced in place. Otherwise a new
 * "model" entry holding only [textPart] is added.
 *
 * Compared with the previous revision, the chunk is appended to the last text part only
 * (previously it was appended to every text part of the entry, duplicating it), and the entry is
 * replaced by index rather than through `removeLast()`.
 *
 * @param tempHistory contents accumulated for the in-flight exchange; mutated in place.
 * @param textPart the newly received text chunk.
 */
private fun addTextToHistory(tempHistory: MutableList<Content>, textPart: TextPart) {
  val lastContent = tempHistory.lastOrNull()
  val lastTextIndex =
    if (lastContent?.role == "model") lastContent.parts.indexOfLast { it is TextPart } else -1

  if (lastContent == null || lastTextIndex < 0) {
    tempHistory.add(Content("model", listOf(textPart)))
    return
  }

  val mergedParts =
    lastContent.parts.mapIndexed { index, existing ->
      if (index == lastTextIndex && existing is TextPart) {
        TextPart(existing.text + textPart.text)
      } else {
        existing
      }
    }
  // Index assignment avoids `removeLast()`, which can resolve to the JDK 21 `List.removeLast`
  // when compiling against newer Android SDKs and then fail on older devices.
  tempHistory[tempHistory.lastIndex] = Content("model", mergedParts)
}

private fun attemptLock() {
Expand Down
Loading

0 comments on commit 58e46ab

Please sign in to comment.