Skip to content

Commit

Permalink
remove auto function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
David Motsonashvili committed Mar 28, 2024
1 parent c09e7ce commit c658b71
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 101 deletions.
134 changes: 37 additions & 97 deletions generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,16 @@ package com.google.ai.client.generativeai
import android.graphics.Bitmap
import com.google.ai.client.generativeai.type.BlobPart
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.FunctionResponsePart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.ImagePart
import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.TextPart
import com.google.ai.client.generativeai.type.content
import java.util.LinkedList
import java.util.concurrent.Semaphore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.transform
import kotlinx.coroutines.flow.onEach

/**
* Representation of a back and forth interaction with a model.
Expand All @@ -55,27 +53,13 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(inputPrompt: Content): GenerateContentResponse {
inputPrompt.assertComesFromUser()
suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()
attemptLock()
var response: GenerateContentResponse
var prompt = inputPrompt
val tempHistory = mutableListOf<Content>()
try {
while (true) {
response =
model.generateContent(*history.toTypedArray(), *tempHistory.toTypedArray(), prompt)
val responsePart = response.candidates.first().content.parts.first()

tempHistory.add(prompt)
tempHistory.add(response.candidates.first().content)
if (responsePart !is FunctionCallPart) break
if (model.requestOptions.disableAutoFunction) break

val output = model.executeFunction(responsePart)
prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output)))
}
history.addAll(tempHistory)
val response = model.generateContent(*history.toTypedArray(), prompt)
history.add(prompt)
history.add(response.candidates.first().content)
return response
} finally {
lock.release()
Expand Down Expand Up @@ -117,19 +101,43 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
attemptLock()

val flow = model.generateContentStream(*history.toTypedArray(), prompt)
val tempHistory = mutableListOf<Content>()
tempHistory.add(prompt)
val bitmaps = LinkedList<Bitmap>()
val blobs = LinkedList<BlobPart>()
val text = StringBuilder()

/**
* TODO: revisit when images and blobs are returned. Parts are collected by type, so their
* original order in the response is lost; e.g. a text/image/text response will be (incorrectly)
* represented as image/text.
*/
return flow
.transform { response -> automaticFunctionExecutingTransform(this, tempHistory, response) }
.onEach {
for (part in it.candidates.first().content.parts) {
when (part) {
is TextPart -> text.append(part.text)
is ImagePart -> bitmaps.add(part.image)
is BlobPart -> blobs.add(part)
}
}
}
.onCompletion {
lock.release()
if (it == null) {
history.addAll(tempHistory)
val content =
content("model") {
for (bitmap in bitmaps) {
image(bitmap)
}
for (blob in blobs) {
blob(blob.mimeType, blob.blob)
}
if (text.isNotBlank()) {
text(text.toString())
}
}

history.add(prompt)
history.add(content)
}
}
}
Expand Down Expand Up @@ -159,75 +167,11 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}

private fun Content.assertComesFromUser() {
if (role !in VALID_ROLES) {
throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.")
}
}

private suspend fun automaticFunctionExecutingTransform(
transformer: FlowCollector<GenerateContentResponse>,
tempHistory: MutableList<Content>,
response: GenerateContentResponse,
) {
for (part in response.candidates.first().content.parts) {
when (part) {
is TextPart -> {
transformer.emit(response)
addTextToHistory(tempHistory, part)
}
is ImagePart -> {
transformer.emit(response)
tempHistory.add(Content("model", listOf(part)))
}
is BlobPart -> {
transformer.emit(response)
tempHistory.add(Content("model", listOf(part)))
}
is FunctionCallPart -> {
if (model.requestOptions.disableAutoFunction) {
tempHistory.add(response.candidates.first().content)
continue
}
val functionCall =
response.candidates.first().content.parts.first { it is FunctionCallPart }
as FunctionCallPart
val output = model.executeFunction(functionCall)
val functionResponse =
Content("function", listOf(FunctionResponsePart(functionCall.name, output)))
tempHistory.add(response.candidates.first().content)
tempHistory.add(functionResponse)
model
.generateContentStream(*history.toTypedArray(), *tempHistory.toTypedArray())
.collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) }
}
}
if (role != "user") {
throw InvalidStateException("Chat prompts should come from the 'user' role.")
}
}

private fun addTextToHistory(tempHistory: MutableList<Content>, textPart: TextPart) {
val lastContent = tempHistory.lastOrNull()
if (lastContent?.role == "model" && lastContent.parts.any { it is TextPart }) {
tempHistory.removeLast()
val editedContent =
Content(
"model",
lastContent.parts.map {
when (it) {
is TextPart -> {
TextPart(it.text + textPart.text)
}
else -> {
it
}
}
},
)
tempHistory.add(editedContent)
return
}
tempHistory.add(Content("model", listOf(textPart)))
}

private fun attemptLock() {
if (!lock.tryAcquire()) {
throw InvalidStateException(
Expand All @@ -236,8 +180,4 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
)
}
}

companion object {
private val VALID_ROLES = listOf("user", "function")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,17 @@ import kotlin.time.toDuration
* @property timeout the maximum amount of time for a request to take, from the first request to
* first response.
* @property apiVersion the api endpoint to call.
* @property disableAutoFunction if true, auto functions will not be automatically executed
*/
class RequestOptions(
val timeout: Duration,
val apiVersion: String = "v1",
val disableAutoFunction: Boolean = false,
) {
/**
* Secondary constructor that accepts the timeout as a raw number of milliseconds instead of a
* [Duration], then delegates to the primary constructor.
*
* @param timeout the timeout in milliseconds; `null` is treated the same as [Long.MAX_VALUE].
* @param apiVersion passed through unchanged to the primary constructor.
* @param disableAutoFunction passed through unchanged to the primary constructor.
*/
@JvmOverloads
constructor(
timeout: Long? = Long.MAX_VALUE,
apiVersion: String = "v1",
disableAutoFunction: Boolean = false,
) : this(
// A null timeout falls back to Long.MAX_VALUE before being converted to a Duration.
(timeout ?: Long.MAX_VALUE).toDuration(DurationUnit.MILLISECONDS),
apiVersion,
disableAutoFunction,
)
}

0 comments on commit c658b71

Please sign in to comment.