diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index aae2e5ff..b536c244 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -19,18 +19,16 @@ package com.google.ai.client.generativeai import android.graphics.Bitmap import com.google.ai.client.generativeai.type.BlobPart import com.google.ai.client.generativeai.type.Content -import com.google.ai.client.generativeai.type.FunctionCallPart -import com.google.ai.client.generativeai.type.FunctionResponsePart import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.ImagePart import com.google.ai.client.generativeai.type.InvalidStateException import com.google.ai.client.generativeai.type.TextPart import com.google.ai.client.generativeai.type.content +import java.util.LinkedList import java.util.concurrent.Semaphore import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.FlowCollector import kotlinx.coroutines.flow.onCompletion -import kotlinx.coroutines.flow.transform +import kotlinx.coroutines.flow.onEach /** * Representation of a back and forth interaction with a model. @@ -55,27 +53,13 @@ class Chat(private val model: GenerativeModel, val history: MutableList * @throws InvalidStateException if the prompt is not coming from the 'user' role * @throws InvalidStateException if the [Chat] instance has an active request. */ - suspend fun sendMessage(inputPrompt: Content): GenerateContentResponse { - inputPrompt.assertComesFromUser() + suspend fun sendMessage(prompt: Content): GenerateContentResponse { + prompt.assertComesFromUser() attemptLock() - var response: GenerateContentResponse - var prompt = inputPrompt - val tempHistory = mutableListOf() try { - while (true) { - response = - model.generateContent(*history.toTypedArray(), *tempHistory.toTypedArray(), prompt) - val responsePart = response.candidates.first().content.parts.first() - - tempHistory.add(prompt) - tempHistory.add(response.candidates.first().content) - if (responsePart !is FunctionCallPart) break - if (model.requestOptions.disableAutoFunction) break - - val output = model.executeFunction(responsePart) - prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output))) - } - history.addAll(tempHistory) + val response = model.generateContent(*history.toTypedArray(), prompt) + history.add(prompt) + history.add(response.candidates.first().content) return response } finally { lock.release() @@ -117,19 +101,43 @@ class Chat(private val model: GenerativeModel, val history: MutableList attemptLock() val flow = model.generateContentStream(*history.toTypedArray(), prompt) - val tempHistory = mutableListOf() - tempHistory.add(prompt) + val bitmaps = LinkedList() + val blobs = LinkedList() + val text = StringBuilder() + /** * TODO: revisit when images and blobs are returned. This will cause issues with how things are * structured in the response. eg; a text/image/text response will be (incorrectly) * represented as image/text */ return flow - .transform { response -> automaticFunctionExecutingTransform(this, tempHistory, response) } + .onEach { + for (part in it.candidates.first().content.parts) { + when (part) { + is TextPart -> text.append(part.text) + is ImagePart -> bitmaps.add(part.image) + is BlobPart -> blobs.add(part) + } + } + } .onCompletion { lock.release() if (it == null) { - history.addAll(tempHistory) + val content = + content("model") { + for (bitmap in bitmaps) { + image(bitmap) + } + for (blob in blobs) { + blob(blob.mimeType, blob.blob) + } + if (text.isNotBlank()) { + text(text.toString()) + } + } + + history.add(prompt) + history.add(content) } } } @@ -159,75 +167,11 @@ class Chat(private val model: GenerativeModel, val history: MutableList } private fun Content.assertComesFromUser() { - if (role !in VALID_ROLES) { - throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.") - } - } - - private suspend fun automaticFunctionExecutingTransform( - transformer: FlowCollector, - tempHistory: MutableList, - response: GenerateContentResponse, - ) { - for (part in response.candidates.first().content.parts) { - when (part) { - is TextPart -> { - transformer.emit(response) - addTextToHistory(tempHistory, part) - } - is ImagePart -> { - transformer.emit(response) - tempHistory.add(Content("model", listOf(part))) - } - is BlobPart -> { - transformer.emit(response) - tempHistory.add(Content("model", listOf(part))) - } - is FunctionCallPart -> { - if (model.requestOptions.disableAutoFunction) { - tempHistory.add(response.candidates.first().content) - continue - } - val functionCall = - response.candidates.first().content.parts.first { it is FunctionCallPart } - as FunctionCallPart - val output = model.executeFunction(functionCall) - val functionResponse = - Content("function", listOf(FunctionResponsePart(functionCall.name, output))) - tempHistory.add(response.candidates.first().content) - tempHistory.add(functionResponse) - model - .generateContentStream(*history.toTypedArray(), *tempHistory.toTypedArray()) - .collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) } - } - } + if (role != "user") { + throw InvalidStateException("Chat prompts should come from the 'user' role.") } } - private fun addTextToHistory(tempHistory: MutableList, textPart: TextPart) { - val lastContent = tempHistory.lastOrNull() - if (lastContent?.role == "model" && lastContent.parts.any { it is TextPart }) { - tempHistory.removeLast() - val editedContent = - Content( - "model", - lastContent.parts.map { - when (it) { - is TextPart -> { - TextPart(it.text + textPart.text) - } - else -> { - it - } - } - }, - ) - tempHistory.add(editedContent) - return - } - tempHistory.add(Content("model", listOf(textPart))) - } - private fun attemptLock() { if (!lock.tryAcquire()) { throw InvalidStateException( @@ -236,8 +180,4 @@ class Chat(private val model: GenerativeModel, val history: MutableList ) } } - - companion object { - private val VALID_ROLES = listOf("user", "function") - } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt index 147a9895..a2470ef2 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt @@ -26,21 +26,17 @@ import kotlin.time.toDuration * @property timeout the maximum amount of time for a request to take, from the first request to * first response. * @property apiVersion the api endpoint to call. - * @property disableAutoFunction if true, auto functions will not be automatically executed */ class RequestOptions( val timeout: Duration, val apiVersion: String = "v1", - val disableAutoFunction: Boolean = false, ) { @JvmOverloads constructor( timeout: Long? = Long.MAX_VALUE, apiVersion: String = "v1", - disableAutoFunction: Boolean = false, ) : this( (timeout ?: Long.MAX_VALUE).toDuration(DurationUnit.MILLISECONDS), apiVersion, - disableAutoFunction, ) }