From f95f85da127bf9a0e680fb954f089698d29ab609 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Tue, 30 Jan 2024 14:36:06 -0800 Subject: [PATCH 01/36] Intial implementation of automated function calling includes a builder to make it easier to add functions DO NOT MERGE --- .../com/google/ai/client/generativeai/Chat.kt | 118 ++++++++--- .../ai/client/generativeai/GenerativeModel.kt | 72 ++++++- .../internal/api/APIController.kt | 10 +- .../generativeai/internal/api/Request.kt | 2 + .../generativeai/internal/api/client/Types.kt | 17 ++ .../generativeai/internal/api/shared/Types.kt | 14 ++ .../generativeai/internal/util/conversions.kt | 52 +++++ .../generativeai/type/FunctionDeclarations.kt | 183 ++++++++++++++++++ .../generativeai/type/FunctionParameter.kt | 3 + .../ai/client/generativeai/type/Part.kt | 6 + .../ai/client/generativeai/type/Tool.kt | 11 ++ 11 files changed, 453 insertions(+), 35 deletions(-) create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index 2abe4b72..2e7a1185 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -19,6 +19,8 @@ package com.google.ai.client.generativeai import android.graphics.Bitmap import com.google.ai.client.generativeai.type.BlobPart import com.google.ai.client.generativeai.type.Content +import com.google.ai.client.generativeai.type.FunctionCallPart +import com.google.ai.client.generativeai.type.FunctionResponsePart import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.ImagePart import com.google.ai.client.generativeai.type.InvalidStateException @@ -27,8 +29,10 @@ import com.google.ai.client.generativeai.type.content import java.util.LinkedList import java.util.concurrent.Semaphore import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.FlowCollector import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.transform /** * Representation of a back and forth interaction with a model. @@ -53,13 +57,25 @@ class Chat(private val model: GenerativeModel, val history: MutableList * @throws InvalidStateException if the prompt is not coming from the 'user' role * @throws InvalidStateException if the [Chat] instance has an active request. */ - suspend fun sendMessage(prompt: Content): GenerateContentResponse { - prompt.assertComesFromUser() + suspend fun sendMessage(inputPrompt: Content): GenerateContentResponse { + inputPrompt.assertComesFromUser() attemptLock() + var response: GenerateContentResponse + var prompt = inputPrompt try { - val response = model.generateContent(*history.toTypedArray(), prompt) - history.add(prompt) - history.add(response.candidates.first().content) + while (true) { + response = model.generateContent(*history.toTypedArray(), prompt) + val responsePart = response.candidates.first().content.parts.first() + + history.add(prompt) + history.add(response.candidates.first().content) + if (responsePart is FunctionCallPart) { + val output = model.executeFunction(responsePart) + prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output))) + } else { + break + } + } return response } finally { lock.release() @@ -101,43 +117,21 @@ class Chat(private val model: GenerativeModel, val history: MutableList attemptLock() val flow = model.generateContentStream(*history.toTypedArray(), prompt) - val bitmaps = LinkedList() - val blobs = LinkedList() - val text = StringBuilder() - + val tempHistory = LinkedList() + tempHistory.add(prompt) /** * TODO: revisit when images and blobs are returned. This will cause issues with how things are * structured in the response. eg; a text/image/text response will be (incorrectly) * represented as image/text */ return flow - .onEach { - for (part in it.candidates.first().content.parts) { - when (part) { - is TextPart -> text.append(part.text) - is ImagePart -> bitmaps.add(part.image) - is BlobPart -> blobs.add(part) - } - } + .transform { response -> + automaticFunctionExecutingTransform(this, tempHistory, response) } .onCompletion { lock.release() if (it == null) { - val content = - content("model") { - for (bitmap in bitmaps) { - image(bitmap) - } - for (blob in blobs) { - blob(blob.mimeType, blob.blob) - } - if (text.isNotBlank()) { - text(text.toString()) - } - } - - history.add(prompt) - history.add(content) + history.addAll(tempHistory) } } } @@ -172,6 +166,66 @@ class Chat(private val model: GenerativeModel, val history: MutableList } } + private suspend fun automaticFunctionExecutingTransform( + transformer: FlowCollector, + tempHistory: LinkedList, + response: GenerateContentResponse + ) { + for (part in response.candidates.first().content.parts) { + when (part) { + is TextPart -> { + transformer.emit(response) + addTextToHistory(tempHistory, part) + } + is ImagePart -> { + transformer.emit(response) + tempHistory.add(Content("model", listOf(part))) + } + is BlobPart -> { + transformer.emit(response) + tempHistory.add(Content("model", listOf(part))) + } + is FunctionCallPart -> { + val functionCall = + response.candidates.first().content.parts.first { it is FunctionCallPart } + as FunctionCallPart + val output = model.executeFunction(functionCall) + val functionResponse = + Content("function", listOf(FunctionResponsePart(functionCall.name, output))) + tempHistory.add(response.candidates.first().content) + tempHistory.add(functionResponse) + model + .generateContentStream(*history.toTypedArray(), *tempHistory.toTypedArray()) + .collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) } + } + } + } + } + + private fun addTextToHistory(tempHistory: LinkedList, textPart: TextPart) { + val lastContent = tempHistory.lastOrNull() + if (lastContent?.role == "model" && lastContent.parts.any { it is TextPart }) { + tempHistory.removeLast() + val editedContent = + Content( + "model", + lastContent.parts.map { + when (it) { + is TextPart -> { + TextPart(it.text + textPart.text) + } + else -> { + it + } + } + } + ) + tempHistory.add(editedContent) + return + } + tempHistory.add(Content("model", listOf(textPart))) + } + private fun attemptLock() { if (!lock.tryAcquire()) { throw InvalidStateException( diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 15064aa7..6422ea9d 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -25,13 +25,20 @@ import com.google.ai.client.generativeai.internal.util.toPublic import com.google.ai.client.generativeai.type.Content import com.google.ai.client.generativeai.type.CountTokensResponse import com.google.ai.client.generativeai.type.FinishReason +import com.google.ai.client.generativeai.type.FourParameterFunction +import com.google.ai.client.generativeai.type.FunctionCallPart import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.GenerationConfig import com.google.ai.client.generativeai.type.GoogleGenerativeAIException +import com.google.ai.client.generativeai.type.NoParameterFunction +import com.google.ai.client.generativeai.type.OneParameterFunction import com.google.ai.client.generativeai.type.PromptBlockedException import com.google.ai.client.generativeai.type.ResponseStoppedException import com.google.ai.client.generativeai.type.SafetySetting import com.google.ai.client.generativeai.type.SerializationException +import com.google.ai.client.generativeai.type.ThreeParameterFunction +import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.TwoParameterFunction import com.google.ai.client.generativeai.type.content import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.catch @@ -52,6 +59,7 @@ internal constructor( val apiKey: String, val generationConfig: GenerationConfig? = null, val safetySettings: List? = null, + val tools: List? = null, private val controller: APIController ) { @@ -61,7 +69,15 @@ internal constructor( apiKey: String, generationConfig: GenerationConfig? = null, safetySettings: List? = null, - ) : this(modelName, apiKey, generationConfig, safetySettings, APIController(apiKey, modelName)) + tools: List? = null + ) : this( + modelName, + apiKey, + generationConfig, + safetySettings, + tools, + APIController(apiKey, modelName) + ) /** * Generates a response from the backend with the provided [Content]s. @@ -160,12 +176,64 @@ internal constructor( return countTokens(content { image(prompt) }) } + /** + * Executes a function request by the model. + * + * @param call A [FunctionCallPart] from the model, containing a function call and parameters + * @return The output of the requested function call + */ + suspend fun executeFunction(call: FunctionCallPart): String { + if (tools == null) { + throw RuntimeException("No registered tools") + } + val tool = tools.first { it.functionDeclarations.any { it.name == call.name } } + val declaration = + tool.functionDeclarations.firstOrNull() { it.name == call.name } + ?: throw RuntimeException("No registered function named ${call.name}") + return when (declaration) { + is NoParameterFunction -> { + declaration.function.invoke() + } + is OneParameterFunction -> { + val param1 = getParamOrThrow(declaration.param.name, call) + declaration.function.invoke(param1) + } + is TwoParameterFunction -> { + val param1 = getParamOrThrow(declaration.param1.name, call) + val param2 = getParamOrThrow(declaration.param2.name, call) + declaration.function.invoke(param1, param2) + } + is ThreeParameterFunction -> { + val param1 = getParamOrThrow(declaration.param1.name, call) + val param2 = getParamOrThrow(declaration.param2.name, call) + val param3 = getParamOrThrow(declaration.param3.name, call) + declaration.function.invoke(param1, param2, param3) + } + is FourParameterFunction -> { + val param1 = getParamOrThrow(declaration.param1.name, call) + val param2 = getParamOrThrow(declaration.param2.name, call) + val param3 = getParamOrThrow(declaration.param3.name, call) + val param4 = getParamOrThrow(declaration.param4.name, call) + declaration.function.invoke(param1, param2, param3, param4) + } + else -> { + throw RuntimeException("UNREACHABLE") + } + } + } + + private fun getParamOrThrow(paramName: String, part: FunctionCallPart): String { + return part.args[paramName] + ?: throw RuntimeException("Missing parameter named $paramName for function ${part.name}") + } + private fun constructRequest(vararg prompt: Content) = GenerateContentRequest( modelName, prompt.map { it.toInternal() }, safetySettings?.map { it.toInternal() }, - generationConfig?.toInternal() + generationConfig?.toInternal(), + tools?.map { it.toInternal() } ) private fun constructCountTokensRequest(vararg prompt: Content) = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt index e46016ca..faf64cb1 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt @@ -25,6 +25,10 @@ import io.ktor.client.engine.HttpClientEngine import io.ktor.client.engine.okhttp.OkHttp import io.ktor.client.plugins.HttpTimeout import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.plugins.logging.ANDROID +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logger +import io.ktor.client.plugins.logging.Logging import io.ktor.client.request.HttpRequestBuilder import io.ktor.client.request.header import io.ktor.client.request.post @@ -44,7 +48,7 @@ import kotlinx.coroutines.launch import kotlinx.serialization.json.Json // TODO: Should these stay here or be moved elsewhere? -internal const val DOMAIN = "https://generativelanguage.googleapis.com/v1" +internal const val DOMAIN = "https://generativelanguage.googleapis.com/v1beta" internal val JSON = Json { ignoreUnknownKeys = true @@ -74,6 +78,10 @@ internal class APIController( requestTimeoutMillis = HttpTimeout.INFINITE_TIMEOUT_MS socketTimeoutMillis = 80_000 } + install(Logging) { + logger = Logger.ANDROID + level = LogLevel.BODY + } install(ContentNegotiation) { json(JSON) } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/Request.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/Request.kt index 44cbe85d..1e98d2f6 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/Request.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/Request.kt @@ -17,6 +17,7 @@ package com.google.ai.client.generativeai.internal.api import com.google.ai.client.generativeai.internal.api.client.GenerationConfig +import com.google.ai.client.generativeai.internal.api.client.Tool import com.google.ai.client.generativeai.internal.api.shared.Content import com.google.ai.client.generativeai.internal.api.shared.SafetySetting import kotlinx.serialization.SerialName @@ -30,6 +31,7 @@ internal data class GenerateContentRequest( val contents: List, @SerialName("safety_settings") val safetySettings: List? = null, @SerialName("generation_config") val generationConfig: GenerationConfig? = null, + val tools: List? = null ) : Request @Serializable diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/client/Types.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/client/Types.kt index f1f778bf..a50db908 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/client/Types.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/client/Types.kt @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.internal.api.client import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject @Serializable internal data class GenerationConfig( @@ -28,3 +29,19 @@ internal data class GenerationConfig( @SerialName("max_output_tokens") val maxOutputTokens: Int?, @SerialName("stop_sequences") val stopSequences: List? ) + +@Serializable internal data class Tool(val functionDeclarations: List) + +@Serializable +internal data class FunctionDeclaration( + val name: String, + val description: String, + val parameters: FunctionParameters +) + +@Serializable +internal data class FunctionParameters( + val properties: JsonObject, + val required: List, + val type: String, +) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt index 3f15f807..1993d8d1 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt @@ -48,12 +48,24 @@ typealias Base64 = String @Serializable internal data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part +@Serializable internal data class FunctionCallPart(val functionCall: FunctionCall) : Part + +@Serializable +internal data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part + @Serializable internal data class Blob( @SerialName("mime_type") val mimeType: String, val data: Base64, ) +@Serializable +internal data class FunctionResponse(val name: String, val response: FunctionResponseData) + +@Serializable internal data class FunctionCall(val name: String, val args: Map) + +@Serializable internal data class FunctionResponseData(val name: String, val content: String) + @Serializable internal data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold) @@ -72,6 +84,8 @@ internal object PartSerializer : JsonContentPolymorphicSerializer(Part::cl return when { "text" in jsonObject -> TextPart.serializer() "inlineData" in jsonObject -> BlobPart.serializer() + "functionCall" in jsonObject -> FunctionCallPart.serializer() + "functionResponse" in jsonObject -> FunctionResponsePart.serializer() else -> throw SerializationException("Unknown Part type") } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index df0d5285..0180df4b 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -32,6 +32,11 @@ import com.google.ai.client.generativeai.internal.api.server.SafetyRating import com.google.ai.client.generativeai.internal.api.shared.Blob import com.google.ai.client.generativeai.internal.api.shared.BlobPart import com.google.ai.client.generativeai.internal.api.shared.Content +import com.google.ai.client.generativeai.internal.api.shared.FunctionCall +import com.google.ai.client.generativeai.internal.api.shared.FunctionCallPart +import com.google.ai.client.generativeai.internal.api.shared.FunctionResponse +import com.google.ai.client.generativeai.internal.api.shared.FunctionResponseData +import com.google.ai.client.generativeai.internal.api.shared.FunctionResponsePart import com.google.ai.client.generativeai.internal.api.shared.HarmBlockThreshold import com.google.ai.client.generativeai.internal.api.shared.HarmCategory import com.google.ai.client.generativeai.internal.api.shared.Part @@ -39,10 +44,14 @@ import com.google.ai.client.generativeai.internal.api.shared.SafetySetting import com.google.ai.client.generativeai.internal.api.shared.TextPart import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata +import com.google.ai.client.generativeai.type.FunctionDeclaration import com.google.ai.client.generativeai.type.ImagePart import com.google.ai.client.generativeai.type.SerializationException +import com.google.ai.client.generativeai.type.Tool import com.google.ai.client.generativeai.type.content import java.io.ByteArrayOutputStream +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject private const val BASE_64_FLAGS = Base64.NO_WRAP @@ -55,6 +64,10 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part { is ImagePart -> BlobPart(Blob("image/jpeg", encodeBitmapToBase64Png(image))) is com.google.ai.client.generativeai.type.BlobPart -> BlobPart(Blob(mimeType, Base64.encodeToString(blob, BASE_64_FLAGS))) + is com.google.ai.client.generativeai.type.FunctionCallPart -> + FunctionCallPart(FunctionCall(name, args)) + is com.google.ai.client.generativeai.type.FunctionResponsePart -> + FunctionResponsePart(FunctionResponse(name, FunctionResponseData(name, response))) else -> throw SerializationException( "The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet." @@ -95,6 +108,35 @@ internal fun BlockThreshold.toInternal() = BlockThreshold.UNSPECIFIED -> HarmBlockThreshold.UNSPECIFIED } +internal fun Tool.toInternal() = + com.google.ai.client.generativeai.internal.api.client.Tool( + functionDeclarations.map { it.toInternal() } + ) + +internal fun FunctionDeclaration.toInternal(): + com.google.ai.client.generativeai.internal.api.client.FunctionDeclaration { + val convertedParams = buildJsonObject { + getParameters().forEach { + put( + it.name, + buildJsonObject { + put("type", JsonPrimitive("STRING")) + put("description", JsonPrimitive(it.description)) + } + ) + } + } + return com.google.ai.client.generativeai.internal.api.client.FunctionDeclaration( + name, + description, + com.google.ai.client.generativeai.internal.api.client.FunctionParameters( + convertedParams, + getParameters().map { it.name }, + "OBJECT" + ) + ) +} + internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candidate { val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty() @@ -122,6 +164,16 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part { com.google.ai.client.generativeai.type.BlobPart(inlineData.mimeType, data) } } + is FunctionCallPart -> + com.google.ai.client.generativeai.type.FunctionCallPart( + functionCall.name, + functionCall.args.orEmpty() + ) + is FunctionResponsePart -> + com.google.ai.client.generativeai.type.FunctionResponsePart( + functionResponse.name, + functionResponse.response.content + ) } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt new file mode 100644 index 00000000..1871e8f7 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -0,0 +1,183 @@ +package com.google.ai.client.generativeai.type + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property function the function implementation + */ +class NoParameterFunction( + name: String, + description: String, + val function: suspend () -> String, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf() +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param A description of the first function parameter + * @property function the function implementation + */ +class OneParameterFunction( + name: String, + description: String, + val param: FunctionParameter, + val function: suspend (String) -> String, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param) +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property function the function implementation + */ +class TwoParameterFunction( + name: String, + description: String, + val param1: FunctionParameter, + val param2: FunctionParameter, + val function: suspend (String, String) -> String, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2) +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property param3 A description of the third function parameter + * @property function the function implementation + */ +class ThreeParameterFunction( + name: String, + description: String, + val param1: FunctionParameter, + val param2: FunctionParameter, + val param3: FunctionParameter, + val function: suspend (String, String, String) -> String, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2, param3) +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property param3 A description of the third function parameter + * @property param4 A description of the fourth function parameter + * @property function the function implementation + */ +class FourParameterFunction( + name: String, + description: String, + val param1: FunctionParameter, + val param2: FunctionParameter, + val param3: FunctionParameter, + val param4: FunctionParameter, + val function: suspend (String, String, String, String) -> String, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2, param3, param4) +} + +abstract class FunctionDeclaration( + val name: String, + val description: String, +) { + abstract fun getParameters(): List +} + +/** + * A builder to help build [FunctionDeclaration] objects + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + */ +class FunctionBuilder(private val name: String, private val description: String) { + + fun build(function: suspend () -> String): FunctionDeclaration { + return NoParameterFunction(name, description, function) + } + + fun param(param: FunctionParameter): OneFunctionBuilder { + return OneFunctionBuilder(name, description, param) + } +} + +class OneFunctionBuilder( + private val name: String, + private val description: String, + private val param1: FunctionParameter +) { + fun build(function: suspend (String) -> String): FunctionDeclaration { + return OneParameterFunction(name, description, param1, function) + } + + fun param(param: FunctionParameter): TwoFunctionBuilder { + return TwoFunctionBuilder(name, description, param1, param) + } +} + +class TwoFunctionBuilder( + private val name: String, + private val description: String, + private val param1: FunctionParameter, + private val param2: FunctionParameter, +) { + fun build(function: suspend (String, String) -> String): FunctionDeclaration { + return TwoParameterFunction(name, description, param1, param2, function) + } + + fun param(param: FunctionParameter): ThreeFunctionBuilder { + return ThreeFunctionBuilder(name, description, param1, param2, param) + } +} + +class ThreeFunctionBuilder( + private val name: String, + private val description: String, + private val param1: FunctionParameter, + private val param2: FunctionParameter, + private val param3: FunctionParameter, +) { + fun build(function: suspend (String, String, String) -> String): FunctionDeclaration { + return ThreeParameterFunction(name, description, param1, param2, param3, function) + } + + fun param(param: FunctionParameter): FourFunctionBuilder { + return FourFunctionBuilder(name, description, param1, param2, param3, param) + } +} + +class FourFunctionBuilder( + private val name: String, + private val description: String, + private val param1: FunctionParameter, + private val param2: FunctionParameter, + private val param3: FunctionParameter, + private val param4: FunctionParameter, +) { + fun build(function: suspend (String, String, String, String) -> String): FunctionDeclaration { + return FourParameterFunction(name, description, param1, param2, param3, param4, function) + } +} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt new file mode 100644 index 00000000..cf233cef --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt @@ -0,0 +1,3 @@ +package com.google.ai.client.generativeai.type + +class FunctionParameter(val name: String, val description: String, val type: String) {} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt index 4c54cc43..ebdde06b 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt @@ -40,6 +40,12 @@ class ImagePart(val image: Bitmap) : Part /** Represents binary data with an associated MIME type sent to and received from requests. */ class BlobPart(val mimeType: String, val blob: ByteArray) : Part +/** Represents function call name and params received from requests. */ +class FunctionCallPart(val name: String, val args: Map) : Part + +/** Represents function call output to be returned to the model when it requests a function call */ +class FunctionResponsePart(val name: String, val response: String) : Part + /** @return The part as a [String] if it represents text, and null otherwise */ fun Part.asTextOrNull(): String? = (this as? TextPart)?.text diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt new file mode 100644 index 00000000..1e3344fa --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt @@ -0,0 +1,11 @@ +package com.google.ai.client.generativeai.type + +/** + * Contains a set of function declarations that the model has access to. These can be used to gather + * information, or complete tasks + * + * @param functionDeclarations The set of functions that this tool allows the model access to + */ +class Tool( + val functionDeclarations: List, +) From 34287a8f239c0c45bdf50af547cc073f01eee90f Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 1 Feb 2024 10:08:49 -0800 Subject: [PATCH 02/36] Add opt in annotation for v1beta endpoint --- .../java/com/google/ai/client/generativeai/Chat.kt | 2 ++ .../google/ai/client/generativeai/GenerativeModel.kt | 3 +++ .../client/generativeai/internal/util/conversions.kt | 3 +++ .../ai/client/generativeai/type/BetaGenAiAPI.kt | 6 ++++++ .../client/generativeai/type/FunctionDeclarations.kt | 11 +++++++++++ .../com/google/ai/client/generativeai/type/Tool.kt | 1 + 6 files changed, 26 insertions(+) create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index 2e7a1185..a8153f24 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -17,6 +17,7 @@ package com.google.ai.client.generativeai import android.graphics.Bitmap +import com.google.ai.client.generativeai.type.BetaGenAiAPI import com.google.ai.client.generativeai.type.BlobPart import com.google.ai.client.generativeai.type.Content import com.google.ai.client.generativeai.type.FunctionCallPart @@ -46,6 +47,7 @@ import kotlinx.coroutines.flow.transform * @param model the model to use for the interaction * @property history the previous interactions with the model */ +@OptIn(BetaGenAiAPI::class) class Chat(private val model: GenerativeModel, val history: MutableList = ArrayList()) { private var lock = Semaphore(1) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 6422ea9d..5db9a646 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -22,6 +22,7 @@ import com.google.ai.client.generativeai.internal.api.CountTokensRequest import com.google.ai.client.generativeai.internal.api.GenerateContentRequest import com.google.ai.client.generativeai.internal.util.toInternal import com.google.ai.client.generativeai.internal.util.toPublic +import com.google.ai.client.generativeai.type.BetaGenAiAPI import com.google.ai.client.generativeai.type.Content import com.google.ai.client.generativeai.type.CountTokensResponse import com.google.ai.client.generativeai.type.FinishReason @@ -53,6 +54,7 @@ import kotlinx.coroutines.flow.map * @property safetySettings the safety bounds to use during alongside prompts during content * generation */ +@OptIn(BetaGenAiAPI::class) class GenerativeModel internal constructor( val modelName: String, @@ -182,6 +184,7 @@ internal constructor( * @param call A [FunctionCallPart] from the model, containing a function call and parameters * @return The output of the requested function call */ + @BetaGenAiAPI suspend fun executeFunction(call: FunctionCallPart): String { if (tools == null) { throw RuntimeException("No registered tools") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index 0180df4b..010ca807 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -42,6 +42,7 @@ import com.google.ai.client.generativeai.internal.api.shared.HarmCategory import com.google.ai.client.generativeai.internal.api.shared.Part import com.google.ai.client.generativeai.internal.api.shared.SafetySetting import com.google.ai.client.generativeai.internal.api.shared.TextPart +import com.google.ai.client.generativeai.type.BetaGenAiAPI import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata import com.google.ai.client.generativeai.type.FunctionDeclaration @@ -108,11 +109,13 @@ internal fun BlockThreshold.toInternal() = BlockThreshold.UNSPECIFIED -> HarmBlockThreshold.UNSPECIFIED } +@BetaGenAiAPI internal fun Tool.toInternal() = com.google.ai.client.generativeai.internal.api.client.Tool( functionDeclarations.map { it.toInternal() } ) +@BetaGenAiAPI internal fun FunctionDeclaration.toInternal(): com.google.ai.client.generativeai.internal.api.client.FunctionDeclaration { val convertedParams = buildJsonObject { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt new file mode 100644 index 00000000..2bd4b92c --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt @@ -0,0 +1,6 @@ +package com.google.ai.client.generativeai.type + +@RequiresOptIn(message = "This API is only available on the v1beta endpoint") +@Retention(AnnotationRetention.BINARY) +@Target(AnnotationTarget.CLASS, AnnotationTarget.FUNCTION) +annotation class BetaGenAiAPI diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 1871e8f7..9a83e87b 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -8,6 +8,7 @@ package com.google.ai.client.generativeai.type * @property description A description of what the function does and its output. * @property function the function implementation */ +@BetaGenAiAPI class NoParameterFunction( name: String, description: String, @@ -25,6 +26,7 @@ class NoParameterFunction( * @property param A description of the first function parameter * @property function the function implementation */ +@BetaGenAiAPI class OneParameterFunction( name: String, description: String, @@ -44,6 +46,7 @@ class OneParameterFunction( * @property param2 A description of the second function parameter * @property function the function implementation */ +@BetaGenAiAPI class TwoParameterFunction( name: String, description: String, @@ -65,6 +68,7 @@ class TwoParameterFunction( * @property param3 A description of the third function parameter * @property function the function implementation */ +@BetaGenAiAPI class ThreeParameterFunction( name: String, description: String, @@ -88,6 +92,7 @@ class ThreeParameterFunction( * @property param4 A description of the fourth function parameter * @property function the function implementation */ +@BetaGenAiAPI class FourParameterFunction( name: String, description: String, @@ -100,6 +105,7 @@ class FourParameterFunction( override fun getParameters() = listOf(param1, param2, param3, param4) } +@BetaGenAiAPI abstract class FunctionDeclaration( val name: String, val description: String, @@ -113,6 +119,7 @@ abstract class FunctionDeclaration( * @property name The name of the function call, this should be clear and descriptive for the model * @property description A description of what the function does and its output. */ +@BetaGenAiAPI class FunctionBuilder(private val name: String, private val description: String) { fun build(function: suspend () -> String): FunctionDeclaration { @@ -124,6 +131,7 @@ class FunctionBuilder(private val name: String, private val description: String) } } +@BetaGenAiAPI class OneFunctionBuilder( private val name: String, private val description: String, @@ -138,6 +146,7 @@ class OneFunctionBuilder( } } +@BetaGenAiAPI class TwoFunctionBuilder( private val name: String, private val description: String, @@ -153,6 +162,7 @@ class TwoFunctionBuilder( } } +@BetaGenAiAPI class ThreeFunctionBuilder( private val name: String, private val description: String, @@ -169,6 +179,7 @@ class ThreeFunctionBuilder( } } +@BetaGenAiAPI class FourFunctionBuilder( private val name: String, private val description: String, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt index 1e3344fa..598eb52f 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt @@ -6,6 +6,7 @@ package com.google.ai.client.generativeai.type * * @param functionDeclarations The set of functions that this tool allows the model access to */ +@BetaGenAiAPI class Tool( val functionDeclarations: List, ) From 1d8ebf441e06735d457cd83606ac82670991ec66 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 1 Feb 2024 10:14:17 -0800 Subject: [PATCH 03/36] ktfmt and add licenses --- .../com/google/ai/client/generativeai/Chat.kt | 5 +---- .../ai/client/generativeai/type/BetaGenAiAPI.kt | 16 ++++++++++++++++ .../generativeai/type/FunctionDeclarations.kt | 16 ++++++++++++++++ .../generativeai/type/FunctionParameter.kt | 16 ++++++++++++++++ .../google/ai/client/generativeai/type/Tool.kt | 16 ++++++++++++++++ 5 files changed, 65 insertions(+), 4 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index a8153f24..8defeefd 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -32,7 +32,6 @@ import java.util.concurrent.Semaphore import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.FlowCollector import kotlinx.coroutines.flow.onCompletion -import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.transform /** @@ -127,9 +126,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList * represented as image/text */ return flow - .transform { response -> - automaticFunctionExecutingTransform(this, tempHistory, response) - } + .transform { response -> automaticFunctionExecutingTransform(this, tempHistory, response) } .onCompletion { lock.release() if (it == null) { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt index 2bd4b92c..42bc9beb 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt @@ -1,3 +1,19 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.google.ai.client.generativeai.type @RequiresOptIn(message = "This API is only available on the v1beta endpoint") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 9a83e87b..b7f4a17c 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -1,3 +1,19 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.google.ai.client.generativeai.type /** diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt index cf233cef..7ae83a65 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt @@ -1,3 +1,19 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.google.ai.client.generativeai.type class FunctionParameter(val name: String, val description: String, val type: String) {} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt index 598eb52f..c4a961af 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt @@ -1,3 +1,19 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.google.ai.client.generativeai.type /** From 61e314001f71bfcbd9a93050195ed7e0b54a0e4f Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 1 Feb 2024 10:19:41 -0800 Subject: [PATCH 04/36] fix behavior for non-streaming chat and fix opt in propagation --- .../main/java/com/google/ai/client/generativeai/Chat.kt | 8 +++++--- .../java/com/google/ai/client/generativeai/type/Tool.kt | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index 8defeefd..2b7119ec 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -63,13 +63,14 @@ class Chat(private val model: GenerativeModel, val history: MutableList attemptLock() var response: GenerateContentResponse var prompt = inputPrompt + val tempHistory = LinkedList() try { while (true) { - response = model.generateContent(*history.toTypedArray(), prompt) + response = model.generateContent(*history.toTypedArray(), *tempHistory.toTypedArray(), prompt) val responsePart = response.candidates.first().content.parts.first() - history.add(prompt) - history.add(response.candidates.first().content) + tempHistory.add(prompt) + tempHistory.add(response.candidates.first().content) if (responsePart is FunctionCallPart) { val output = model.executeFunction(responsePart) prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output))) @@ -77,6 +78,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList break } } + history.addAll(tempHistory) return response } finally { lock.release() diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt index c4a961af..404a278b 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt @@ -22,7 +22,7 @@ package com.google.ai.client.generativeai.type * * @param functionDeclarations The set of functions that this tool allows the model access to */ -@BetaGenAiAPI +@OptIn(BetaGenAiAPI::class) class Tool( val functionDeclarations: List, ) From 8c5c79815127db8f68dce135af13f60a2f02b120 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 1 Feb 2024 10:21:43 -0800 Subject: [PATCH 05/36] ktfmt --- .../src/main/java/com/google/ai/client/generativeai/Chat.kt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index 2b7119ec..be304125 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -66,7 +66,8 @@ class Chat(private val model: GenerativeModel, val history: MutableList val tempHistory = LinkedList() try { while (true) { - response = model.generateContent(*history.toTypedArray(), *tempHistory.toTypedArray(), prompt) + response = + model.generateContent(*history.toTypedArray(), *tempHistory.toTypedArray(), prompt) val responsePart = response.candidates.first().content.parts.first() tempHistory.add(prompt) From 7cad4caa07a2e85cdf8f030e6ad37399dbe8501e Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 1 Feb 2024 12:50:39 -0800 Subject: [PATCH 06/36] remove logging --- .../ai/client/generativeai/internal/api/APIController.kt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt index faf64cb1..c58a625e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt @@ -78,10 +78,6 @@ internal class APIController( requestTimeoutMillis = HttpTimeout.INFINITE_TIMEOUT_MS socketTimeoutMillis = 80_000 } - install(Logging) { - logger = Logger.ANDROID - level = LogLevel.BODY - } install(ContentNegotiation) { json(JSON) } } From 1136610a179332b5ac67d35160225294b7bff324 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Wed, 7 Feb 2024 12:45:25 -0800 Subject: [PATCH 07/36] added convinience methods for getting function call and response types, and added a type system for automated function calling --- .../ai/client/generativeai/GenerativeModel.kt | 59 ++-- .../internal/api/APIController.kt | 4 - .../generativeai/type/FunctionDeclarations.kt | 260 +++++++++++++++--- .../generativeai/type/FunctionParameter.kt | 2 +- .../type/GenerateContentResponse.kt | 6 + .../ai/client/generativeai/type/Type.kt | 9 + 6 files changed, 271 insertions(+), 69 deletions(-) create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 5db9a646..72d5e6c3 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -28,6 +28,7 @@ import com.google.ai.client.generativeai.type.CountTokensResponse import com.google.ai.client.generativeai.type.FinishReason import com.google.ai.client.generativeai.type.FourParameterFunction import com.google.ai.client.generativeai.type.FunctionCallPart +import com.google.ai.client.generativeai.type.FunctionParameter import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.GenerationConfig import com.google.ai.client.generativeai.type.GoogleGenerativeAIException @@ -197,27 +198,43 @@ internal constructor( is NoParameterFunction -> { declaration.function.invoke() } - is OneParameterFunction -> { - val param1 = getParamOrThrow(declaration.param.name, call) - declaration.function.invoke(param1) + is OneParameterFunction<*> -> { + declaration + .let { declaration as OneParameterFunction } + .let { + val param1 = getParamOrThrow(it.param, call) + it.function.invoke(param1) + } } - is TwoParameterFunction -> { - val param1 = getParamOrThrow(declaration.param1.name, call) - val param2 = getParamOrThrow(declaration.param2.name, call) - declaration.function.invoke(param1, param2) + is TwoParameterFunction<*, *> -> { + declaration + .let { declaration as TwoParameterFunction } + .let { + val param1 = getParamOrThrow(it.param1, call) + val param2 = getParamOrThrow(it.param2, call) + it.function.invoke(param1, param2) + } } - is ThreeParameterFunction -> { - val param1 = getParamOrThrow(declaration.param1.name, call) - val param2 = getParamOrThrow(declaration.param2.name, call) - val param3 = getParamOrThrow(declaration.param3.name, call) - declaration.function.invoke(param1, param2, param3) + is ThreeParameterFunction<*, *, *> -> { + declaration + .let { declaration as ThreeParameterFunction } + .let { + val param1 = getParamOrThrow(it.param1, call) + val param2 = getParamOrThrow(it.param2, call) + val param3 = getParamOrThrow(it.param3, call) + it.function.invoke(param1, param2, param3) + } } - is FourParameterFunction -> { - val param1 = getParamOrThrow(declaration.param1.name, call) - val param2 = getParamOrThrow(declaration.param2.name, call) - val param3 = getParamOrThrow(declaration.param3.name, call) - val param4 = getParamOrThrow(declaration.param4.name, call) - declaration.function.invoke(param1, param2, param3, param4) + is FourParameterFunction<*, *, *, *> -> { + declaration + .let { declaration as FourParameterFunction } + .let { + val param1 = getParamOrThrow(it.param1, call) + val param2 = getParamOrThrow(it.param2, call) + val param3 = getParamOrThrow(it.param3, call) + val param4 = getParamOrThrow(it.param4, call) + it.function.invoke(param1, param2, param3, param4) + } } else -> { throw RuntimeException("UNREACHABLE") @@ -225,9 +242,9 @@ internal constructor( } } - private fun getParamOrThrow(paramName: String, part: FunctionCallPart): String { - return part.args[paramName] - ?: throw RuntimeException("Missing parameter named $paramName for function ${part.name}") + private fun getParamOrThrow(param: FunctionParameter, part: FunctionCallPart): T { + return param.type.parse.invoke(part.args[param.name]) + ?: throw RuntimeException("Missing parameter named ${param.name} for function ${part.name}") } private fun constructRequest(vararg prompt: Content) = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt index c58a625e..aef3ab93 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt @@ -25,10 +25,6 @@ import io.ktor.client.engine.HttpClientEngine import io.ktor.client.engine.okhttp.OkHttp import io.ktor.client.plugins.HttpTimeout import io.ktor.client.plugins.contentnegotiation.ContentNegotiation -import io.ktor.client.plugins.logging.ANDROID -import io.ktor.client.plugins.logging.LogLevel -import io.ktor.client.plugins.logging.Logger -import io.ktor.client.plugins.logging.Logging import io.ktor.client.request.HttpRequestBuilder import io.ktor.client.request.header import io.ktor.client.request.post diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index b7f4a17c..fefe11f8 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -30,7 +30,7 @@ class NoParameterFunction( description: String, val function: suspend () -> String, ) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf() + override fun getParameters() = listOf>() } /** @@ -43,11 +43,11 @@ class NoParameterFunction( * @property function the function implementation */ @BetaGenAiAPI -class OneParameterFunction( +class OneParameterFunction( name: String, description: String, - val param: FunctionParameter, - val function: suspend (String) -> String, + val param: FunctionParameter, + val function: suspend (T) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param) } @@ -63,12 +63,12 @@ class OneParameterFunction( * @property function the function implementation */ @BetaGenAiAPI -class TwoParameterFunction( +class TwoParameterFunction( name: String, description: String, - val param1: FunctionParameter, - val param2: FunctionParameter, - val function: suspend (String, String) -> String, + val param1: FunctionParameter, + val param2: FunctionParameter, + val function: suspend (T, U) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2) } @@ -85,13 +85,13 @@ class TwoParameterFunction( * @property function the function implementation */ @BetaGenAiAPI -class ThreeParameterFunction( +class ThreeParameterFunction( name: String, description: String, - val param1: FunctionParameter, - val param2: FunctionParameter, - val param3: FunctionParameter, - val function: suspend (String, String, String) -> String, + val param1: FunctionParameter, + val param2: FunctionParameter, + val param3: FunctionParameter, + val function: suspend (T, U, V) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3) } @@ -109,14 +109,14 @@ class ThreeParameterFunction( * @property function the function implementation */ @BetaGenAiAPI -class FourParameterFunction( +class FourParameterFunction( name: String, description: String, - val param1: FunctionParameter, - val param2: FunctionParameter, - val param3: FunctionParameter, - val param4: FunctionParameter, - val function: suspend (String, String, String, String) -> String, + val param1: FunctionParameter, + val param2: FunctionParameter, + val param3: FunctionParameter, + val param4: FunctionParameter, + val function: suspend (T, U, V, W) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3, param4) } @@ -126,7 +126,7 @@ abstract class FunctionDeclaration( val name: String, val description: String, ) { - abstract fun getParameters(): List + abstract fun getParameters(): List> } /** @@ -142,69 +142,243 @@ class FunctionBuilder(private val name: String, private val description: String) return NoParameterFunction(name, description, function) } - fun param(param: FunctionParameter): OneFunctionBuilder { - return OneFunctionBuilder(name, description, param) + fun param(param: FunctionParameter): OneFunctionBuilder { + return OneFunctionBuilder(name, description, param) + } + + fun param( + paramName: String, + paramDescription: String, + type: FunctionType + ): OneFunctionBuilder { + return OneFunctionBuilder( + name, + description, + FunctionParameter(paramName, paramDescription, type) + ) + } + + fun stringParam(paramName: String, paramDescription: String): OneFunctionBuilder { + return OneFunctionBuilder( + name, + description, + FunctionParameter(paramName, paramDescription, FunctionType.STRING) + ) + } + + fun intParam(paramName: String, paramDescription: String): OneFunctionBuilder { + return OneFunctionBuilder( + name, + description, + FunctionParameter(paramName, paramDescription, FunctionType.INT) + ) + } + + fun boolParam(paramName: String, paramDescription: String): OneFunctionBuilder { + return OneFunctionBuilder( + name, + description, + FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) + ) } } @BetaGenAiAPI -class OneFunctionBuilder( +class OneFunctionBuilder( private val name: String, private val description: String, - private val param1: FunctionParameter + private val param1: FunctionParameter ) { - fun build(function: suspend (String) -> String): FunctionDeclaration { + fun build(function: suspend (T) -> String): FunctionDeclaration { return OneParameterFunction(name, description, param1, function) } - fun param(param: FunctionParameter): TwoFunctionBuilder { + fun param(param: FunctionParameter): TwoFunctionBuilder { return TwoFunctionBuilder(name, description, param1, param) } + + fun param( + paramName: String, + paramDescription: String, + type: FunctionType + ): TwoFunctionBuilder { + return TwoFunctionBuilder( + name, + description, + param1, + FunctionParameter(paramName, paramDescription, type) + ) + } + + fun stringParam(paramName: String, paramDescription: String): TwoFunctionBuilder { + return TwoFunctionBuilder( + name, + description, + param1, + FunctionParameter(paramName, paramDescription, FunctionType.STRING) + ) + } + + fun intParam(paramName: String, paramDescription: String): TwoFunctionBuilder { + return TwoFunctionBuilder( + name, + description, + param1, + FunctionParameter(paramName, paramDescription, FunctionType.INT) + ) + } + + fun boolParam(paramName: String, paramDescription: String): TwoFunctionBuilder { + return TwoFunctionBuilder( + name, + description, + param1, + FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) + ) + } } @BetaGenAiAPI -class TwoFunctionBuilder( +class TwoFunctionBuilder( private val name: String, private val description: String, - private val param1: FunctionParameter, - private val param2: FunctionParameter, + private val param1: FunctionParameter, + private val param2: FunctionParameter, ) { - fun build(function: suspend (String, String) -> String): FunctionDeclaration { + fun build(function: suspend (T, U) -> String): FunctionDeclaration { return TwoParameterFunction(name, description, param1, param2, function) } - fun param(param: FunctionParameter): ThreeFunctionBuilder { + fun param(param: FunctionParameter): ThreeFunctionBuilder { return ThreeFunctionBuilder(name, description, param1, param2, param) } + + fun param( + paramName: String, + paramDescription: String, + type: FunctionType + ): ThreeFunctionBuilder { + return ThreeFunctionBuilder( + name, + description, + param1, + param2, + FunctionParameter(paramName, paramDescription, type) + ) + } + + fun stringParam(paramName: String, paramDescription: String): ThreeFunctionBuilder { + return ThreeFunctionBuilder( + name, + description, + param1, + param2, + FunctionParameter(paramName, paramDescription, FunctionType.STRING) + ) + } + + fun intParam(paramName: String, paramDescription: String): ThreeFunctionBuilder { + return ThreeFunctionBuilder( + name, + description, + param1, + param2, + FunctionParameter(paramName, paramDescription, FunctionType.INT) + ) + } + + fun boolParam(paramName: String, paramDescription: String): ThreeFunctionBuilder { + return ThreeFunctionBuilder( + name, + description, + param1, + param2, + FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) + ) + } } @BetaGenAiAPI -class ThreeFunctionBuilder( +class ThreeFunctionBuilder( private val name: String, private val description: String, - private val param1: FunctionParameter, - private val param2: FunctionParameter, - private val param3: FunctionParameter, + private val param1: FunctionParameter, + private val param2: FunctionParameter, + private val param3: FunctionParameter, ) { - fun build(function: suspend (String, String, String) -> String): FunctionDeclaration { + fun build(function: suspend (T, U, V) -> String): FunctionDeclaration { return ThreeParameterFunction(name, description, param1, param2, param3, function) } - fun param(param: FunctionParameter): FourFunctionBuilder { + fun param(param: FunctionParameter): FourFunctionBuilder { return FourFunctionBuilder(name, description, param1, param2, param3, param) } + + fun param( + paramName: String, + paramDescription: String, + type: FunctionType + ): FourFunctionBuilder { + return FourFunctionBuilder( + name, + description, + param1, + param2, + param3, + FunctionParameter(paramName, paramDescription, type) + ) + } + + fun stringParam( + paramName: String, + paramDescription: String + ): FourFunctionBuilder { + return FourFunctionBuilder( + name, + description, + param1, + param2, + param3, + FunctionParameter(paramName, paramDescription, FunctionType.STRING) + ) + } + + fun intParam(paramName: String, paramDescription: String): FourFunctionBuilder { + return FourFunctionBuilder( + name, + description, + param1, + param2, + param3, + FunctionParameter(paramName, paramDescription, FunctionType.INT) + ) + } + + fun boolParam( + paramName: String, + paramDescription: String + ): FourFunctionBuilder { + return FourFunctionBuilder( + name, + description, + param1, + param2, + param3, + FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) + ) + } } @BetaGenAiAPI -class FourFunctionBuilder( +class FourFunctionBuilder( private val name: String, private val description: String, - private val param1: FunctionParameter, - private val param2: FunctionParameter, - private val param3: FunctionParameter, - private val param4: FunctionParameter, + private val param1: FunctionParameter, + private val param2: FunctionParameter, + private val param3: FunctionParameter, + private val param4: FunctionParameter, ) { - fun build(function: suspend (String, String, String, String) -> String): FunctionDeclaration { + fun build(function: suspend (T, U, V, W) -> String): FunctionDeclaration { return FourParameterFunction(name, description, param1, param2, param3, param4, function) } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt index 7ae83a65..09d9fcf3 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt @@ -16,4 +16,4 @@ package com.google.ai.client.generativeai.type -class FunctionParameter(val name: String, val description: String, val type: String) {} +class FunctionParameter(val name: String, val description: String, val type: FunctionType) {} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt index 9e33ce16..9fe30329 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt @@ -32,6 +32,12 @@ class GenerateContentResponse( /** Convenience field representing the first text part in the response, if it exists. */ val text: String? by lazy { firstPartAs()?.text } + /** Convenience field representing the first text part in the response, if it exists. */ + val functionCall: FunctionCallPart? by lazy { firstPartAs() } + + /** Convenience field representing the first text part in the response, if it exists. */ + val functionResponse: FunctionResponsePart? by lazy { firstPartAs() } + private inline fun firstPartAs(): T? { if (candidates.isEmpty()) { warn("No candidates were found, but was asked to get a candidate.") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt new file mode 100644 index 00000000..53f0dce2 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt @@ -0,0 +1,9 @@ +package com.google.ai.client.generativeai.type + +class FunctionType(val name: String, val parse: (String?) -> T?) { + companion object { + val STRING = FunctionType("STRING") { it } + val INT = FunctionType("INTEGER") { it?.toIntOrNull() } + val BOOLEAN = FunctionType("BOOLEAN") { it?.toBoolean() } + } +} From f8d60e3c7daaedbc196d1f364503ec57cec25048 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 22 Feb 2024 17:37:16 +0000 Subject: [PATCH 08/36] Alternative implementation (#68) Co-authored-by: Rodrigo Lazo Paz --- .../ai/client/generativeai/GenerativeModel.kt | 71 +--- .../generativeai/type/FunctionDeclarations.kt | 354 +++++------------- .../generativeai/GenerativeModelTests.kt | 43 ++- 3 files changed, 158 insertions(+), 310 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 72d5e6c3..ea0cefc2 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -28,12 +28,12 @@ import com.google.ai.client.generativeai.type.CountTokensResponse import com.google.ai.client.generativeai.type.FinishReason import com.google.ai.client.generativeai.type.FourParameterFunction import com.google.ai.client.generativeai.type.FunctionCallPart -import com.google.ai.client.generativeai.type.FunctionParameter import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.GenerationConfig import com.google.ai.client.generativeai.type.GoogleGenerativeAIException import com.google.ai.client.generativeai.type.NoParameterFunction import com.google.ai.client.generativeai.type.OneParameterFunction +import com.google.ai.client.generativeai.type.ParameterDeclaration import com.google.ai.client.generativeai.type.PromptBlockedException import com.google.ai.client.generativeai.type.ResponseStoppedException import com.google.ai.client.generativeai.type.SafetySetting @@ -182,71 +182,34 @@ internal constructor( /** * Executes a function request by the model. * - * @param call A [FunctionCallPart] from the model, containing a function call and parameters + * @param functionCallPart A [FunctionCallPart] from the model, containing a function call and + * parameters * @return The output of the requested function call */ @BetaGenAiAPI - suspend fun executeFunction(call: FunctionCallPart): String { + fun executeFunction(functionCallPart: FunctionCallPart): String { if (tools == null) { throw RuntimeException("No registered tools") } - val tool = tools.first { it.functionDeclarations.any { it.name == call.name } } - val declaration = - tool.functionDeclarations.firstOrNull() { it.name == call.name } - ?: throw RuntimeException("No registered function named ${call.name}") - return when (declaration) { - is NoParameterFunction -> { - declaration.function.invoke() - } - is OneParameterFunction<*> -> { - declaration - .let { declaration as OneParameterFunction } - .let { - val param1 = getParamOrThrow(it.param, call) - it.function.invoke(param1) - } - } - is TwoParameterFunction<*, *> -> { - declaration - .let { declaration as TwoParameterFunction } - .let { - val param1 = getParamOrThrow(it.param1, call) - val param2 = getParamOrThrow(it.param2, call) - it.function.invoke(param1, param2) - } - } - is ThreeParameterFunction<*, *, *> -> { - declaration - .let { declaration as ThreeParameterFunction } - .let { - val param1 = getParamOrThrow(it.param1, call) - val param2 = getParamOrThrow(it.param2, call) - val param3 = getParamOrThrow(it.param3, call) - it.function.invoke(param1, param2, param3) - } - } - is FourParameterFunction<*, *, *, *> -> { - declaration - .let { declaration as FourParameterFunction } - .let { - val param1 = getParamOrThrow(it.param1, call) - val param2 = getParamOrThrow(it.param2, call) - val param3 = getParamOrThrow(it.param3, call) - val param4 = getParamOrThrow(it.param4, call) - it.function.invoke(param1, param2, param3, param4) - } - } + val tool = tools.first { it.functionDeclarations.any { it.name == functionCallPart.name } } + val callable = + tool.functionDeclarations.firstOrNull() { it.name == functionCallPart.name } + ?: throw RuntimeException("No registered function named ${functionCallPart.name}") + return when (callable) { + is NoParameterFunction -> callable() + is OneParameterFunction<*> -> (callable as OneParameterFunction)(functionCallPart) + is TwoParameterFunction<*, *> -> + (callable as TwoParameterFunction)(functionCallPart) + is ThreeParameterFunction<*, *, *> -> + (callable as ThreeParameterFunction)(functionCallPart) + is FourParameterFunction<*, *, *, *> -> + (callable as FourParameterFunction)(functionCallPart) else -> { throw RuntimeException("UNREACHABLE") } } } - private fun getParamOrThrow(param: FunctionParameter, part: FunctionCallPart): T { - return param.type.parse.invoke(part.args[param.name]) - ?: throw RuntimeException("Missing parameter named ${param.name} for function ${part.name}") - } - private fun constructRequest(vararg prompt: Content) = GenerateContentRequest( modelName, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index fefe11f8..ada82d3f 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -28,9 +28,13 @@ package com.google.ai.client.generativeai.type class NoParameterFunction( name: String, description: String, - val function: suspend () -> String, + val function: () -> String, ) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf>() + override fun getParameters() = listOf>() + + operator fun invoke() = function() + + operator fun invoke(part: FunctionCallPart) = invoke() } /** @@ -46,10 +50,15 @@ class NoParameterFunction( class OneParameterFunction( name: String, description: String, - val param: FunctionParameter, - val function: suspend (T) -> String, + val param: ParameterDeclaration, + val function: (T) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param) + + operator fun invoke(part: FunctionCallPart): String { + val arg1 = part.getArgOrThrow(param) + return function(arg1) + } } /** @@ -66,11 +75,17 @@ class OneParameterFunction( class TwoParameterFunction( name: String, description: String, - val param1: FunctionParameter, - val param2: FunctionParameter, - val function: suspend (T, U) -> String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val function: (T, U) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2) + + operator fun invoke(part: FunctionCallPart): String { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + return function(arg1, arg2) + } } /** @@ -88,12 +103,19 @@ class TwoParameterFunction( class ThreeParameterFunction( name: String, description: String, - val param1: FunctionParameter, - val param2: FunctionParameter, - val param3: FunctionParameter, - val function: suspend (T, U, V) -> String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val param3: ParameterDeclaration, + val function: (T, U, V) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3) + + operator fun invoke(part: FunctionCallPart): String { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + return function(arg1, arg2, arg3) + } } /** @@ -112,13 +134,21 @@ class ThreeParameterFunction( class FourParameterFunction( name: String, description: String, - val param1: FunctionParameter, - val param2: FunctionParameter, - val param3: FunctionParameter, - val param4: FunctionParameter, - val function: suspend (T, U, V, W) -> String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val param3: ParameterDeclaration, + val param4: ParameterDeclaration, + val function: (T, U, V, W) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3, param4) + + operator fun invoke(part: FunctionCallPart): String { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + val arg4 = part.getArgOrThrow(param4) + return function(arg1, arg2, arg3, arg4) + } } @BetaGenAiAPI @@ -126,259 +156,73 @@ abstract class FunctionDeclaration( val name: String, val description: String, ) { - abstract fun getParameters(): List> + abstract fun getParameters(): List> } -/** - * A builder to help build [FunctionDeclaration] objects - * - * @property name The name of the function call, this should be clear and descriptive for the model - * @property description A description of what the function does and its output. - */ -@BetaGenAiAPI -class FunctionBuilder(private val name: String, private val description: String) { - - fun build(function: suspend () -> String): FunctionDeclaration { - return NoParameterFunction(name, description, function) - } - - fun param(param: FunctionParameter): OneFunctionBuilder { - return OneFunctionBuilder(name, description, param) - } - - fun param( - paramName: String, - paramDescription: String, - type: FunctionType - ): OneFunctionBuilder { - return OneFunctionBuilder( - name, - description, - FunctionParameter(paramName, paramDescription, type) - ) - } - - fun stringParam(paramName: String, paramDescription: String): OneFunctionBuilder { - return OneFunctionBuilder( - name, - description, - FunctionParameter(paramName, paramDescription, FunctionType.STRING) - ) - } - - fun intParam(paramName: String, paramDescription: String): OneFunctionBuilder { - return OneFunctionBuilder( - name, - description, - FunctionParameter(paramName, paramDescription, FunctionType.INT) - ) - } - - fun boolParam(paramName: String, paramDescription: String): OneFunctionBuilder { - return OneFunctionBuilder( - name, - description, - FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) - ) - } -} - -@BetaGenAiAPI -class OneFunctionBuilder( - private val name: String, - private val description: String, - private val param1: FunctionParameter +class ParameterDeclaration( + val name: String, + val description: String, + private val type: FunctionType ) { - fun build(function: suspend (T) -> String): FunctionDeclaration { - return OneParameterFunction(name, description, param1, function) - } + fun fromString(value: String?) = type.parse(value) - fun param(param: FunctionParameter): TwoFunctionBuilder { - return TwoFunctionBuilder(name, description, param1, param) - } + companion object { + fun int(name: String, description: String) = + ParameterDeclaration(name, description, FunctionType.INT) - fun param( - paramName: String, - paramDescription: String, - type: FunctionType - ): TwoFunctionBuilder { - return TwoFunctionBuilder( - name, - description, - param1, - FunctionParameter(paramName, paramDescription, type) - ) - } + fun string(name: String, description: String) = + ParameterDeclaration(name, description, FunctionType.STRING) - fun stringParam(paramName: String, paramDescription: String): TwoFunctionBuilder { - return TwoFunctionBuilder( - name, - description, - param1, - FunctionParameter(paramName, paramDescription, FunctionType.STRING) - ) - } - - fun intParam(paramName: String, paramDescription: String): TwoFunctionBuilder { - return TwoFunctionBuilder( - name, - description, - param1, - FunctionParameter(paramName, paramDescription, FunctionType.INT) - ) - } - - fun boolParam(paramName: String, paramDescription: String): TwoFunctionBuilder { - return TwoFunctionBuilder( - name, - description, - param1, - FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) - ) + fun boolean(name: String, description: String) = + ParameterDeclaration(name, description, FunctionType.BOOLEAN) } } @BetaGenAiAPI -class TwoFunctionBuilder( - private val name: String, - private val description: String, - private val param1: FunctionParameter, - private val param2: FunctionParameter, -) { - fun build(function: suspend (T, U) -> String): FunctionDeclaration { - return TwoParameterFunction(name, description, param1, param2, function) - } - - fun param(param: FunctionParameter): ThreeFunctionBuilder { - return ThreeFunctionBuilder(name, description, param1, param2, param) - } - - fun param( - paramName: String, - paramDescription: String, - type: FunctionType - ): ThreeFunctionBuilder { - return ThreeFunctionBuilder( - name, - description, - param1, - param2, - FunctionParameter(paramName, paramDescription, type) - ) - } - - fun stringParam(paramName: String, paramDescription: String): ThreeFunctionBuilder { - return ThreeFunctionBuilder( - name, - description, - param1, - param2, - FunctionParameter(paramName, paramDescription, FunctionType.STRING) - ) - } - - fun intParam(paramName: String, paramDescription: String): ThreeFunctionBuilder { - return ThreeFunctionBuilder( - name, - description, - param1, - param2, - FunctionParameter(paramName, paramDescription, FunctionType.INT) - ) - } - - fun boolParam(paramName: String, paramDescription: String): ThreeFunctionBuilder { - return ThreeFunctionBuilder( - name, - description, - param1, - param2, - FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) - ) - } -} +fun defineFunction(name: String, description: String, function: () -> String) = + NoParameterFunction(name, description, function) @BetaGenAiAPI -class ThreeFunctionBuilder( - private val name: String, - private val description: String, - private val param1: FunctionParameter, - private val param2: FunctionParameter, - private val param3: FunctionParameter, -) { - fun build(function: suspend (T, U, V) -> String): FunctionDeclaration { - return ThreeParameterFunction(name, description, param1, param2, param3, function) - } - - fun param(param: FunctionParameter): FourFunctionBuilder { - return FourFunctionBuilder(name, description, param1, param2, param3, param) - } - - fun param( - paramName: String, - paramDescription: String, - type: FunctionType - ): FourFunctionBuilder { - return FourFunctionBuilder( - name, - description, - param1, - param2, - param3, - FunctionParameter(paramName, paramDescription, type) - ) - } - - fun stringParam( - paramName: String, - paramDescription: String - ): FourFunctionBuilder { - return FourFunctionBuilder( - name, - description, - param1, - param2, - param3, - FunctionParameter(paramName, paramDescription, FunctionType.STRING) - ) - } +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + function: (T) -> String +) = OneParameterFunction(name, description, arg1, function) - fun intParam(paramName: String, paramDescription: String): FourFunctionBuilder { - return FourFunctionBuilder( - name, - description, - param1, - param2, - param3, - FunctionParameter(paramName, paramDescription, FunctionType.INT) - ) - } +@BetaGenAiAPI +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + function: (T, U) -> String +) = TwoParameterFunction(name, description, arg1, arg2, function) - fun boolParam( - paramName: String, - paramDescription: String - ): FourFunctionBuilder { - return FourFunctionBuilder( - name, - description, - param1, - param2, - param3, - FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) - ) - } -} +@BetaGenAiAPI +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + arg3: ParameterDeclaration, + function: (T, U, W) -> String +) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) @BetaGenAiAPI -class FourFunctionBuilder( - private val name: String, - private val description: String, - private val param1: FunctionParameter, - private val param2: FunctionParameter, - private val param3: FunctionParameter, - private val param4: FunctionParameter, -) { - fun build(function: suspend (T, U, V, W) -> String): FunctionDeclaration { - return FourParameterFunction(name, description, param1, param2, param3, param4, function) - } +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + arg3: ParameterDeclaration, + arg4: ParameterDeclaration, + function: (T, U, W, Z) -> String +) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) + +private fun FunctionCallPart.getArgOrThrow(param: ParameterDeclaration): T { + return param.fromString(args[param.name]) + ?: throw RuntimeException( + "Missing argument for parameter \"${param.name}\" for function \"$name\"" + ) } diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index da0e90b6..c3920e9f 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -16,6 +16,11 @@ package com.google.ai.client.generativeai +import com.google.ai.client.generativeai.type.BetaGenAiAPI +import com.google.ai.client.generativeai.type.FunctionCallPart +import com.google.ai.client.generativeai.type.ParameterDeclaration +import com.google.ai.client.generativeai.type.TwoParameterFunction +import com.google.ai.client.generativeai.type.defineFunction import com.google.ai.client.generativeai.util.commonTest import com.google.ai.client.generativeai.util.createResponses import com.google.ai.client.generativeai.util.prepareStreamingResponse @@ -23,7 +28,7 @@ import io.kotest.matchers.shouldBe import io.ktor.utils.io.close import io.ktor.utils.io.writeFully import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout import org.junit.Test @@ -45,4 +50,40 @@ internal class GenerativeModelTests { } } } + + // + // + // FOR DEV PURPOSES ONLY + // + // + + fun myfun(a: Int, b: Int): String { + return (a + b).toString() + } + + @OptIn(BetaGenAiAPI::class) + @Test + fun `calling test`(): Unit = runBlocking { + // val f = + // FunctionBuilder("sum", "add two numbers together") + // .intParam("a", "First number to add together") + // .intParam("b", "Second number to add together") + // .build(::myfun) + + val f = + defineFunction( + "sum", + "add two numbers together", + ParameterDeclaration.int("a", "First number to add together"), + ParameterDeclaration.int("b", "Second number to add together") + ) { a, b -> + (a + b).toString() + } + + val x = f as TwoParameterFunction + val p = FunctionCallPart("sum", mapOf("a" to "2", "b" to "3")) + val q = x(p) + + q shouldBe "5" + } } From b09c6ef15bd962f4b39820039b12e4723a1b3935 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 22 Feb 2024 09:45:47 -0800 Subject: [PATCH 09/36] rename annotation and cleanup rlazo@'s merge --- .../com/google/ai/client/generativeai/Chat.kt | 4 +- .../ai/client/generativeai/GenerativeModel.kt | 7 ++- .../generativeai/internal/util/conversions.kt | 6 +-- .../generativeai/type/FunctionDeclarations.kt | 22 ++++----- .../{BetaGenAiAPI.kt => GenerativeBeta.kt} | 2 +- .../ai/client/generativeai/type/Tool.kt | 2 +- .../generativeai/GenerativeModelTests.kt | 46 +------------------ 7 files changed, 23 insertions(+), 66 deletions(-) rename generativeai/src/main/java/com/google/ai/client/generativeai/type/{BetaGenAiAPI.kt => GenerativeBeta.kt} (96%) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index be304125..2f586ffa 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -17,7 +17,7 @@ package com.google.ai.client.generativeai import android.graphics.Bitmap -import com.google.ai.client.generativeai.type.BetaGenAiAPI +import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.BlobPart import com.google.ai.client.generativeai.type.Content import com.google.ai.client.generativeai.type.FunctionCallPart @@ -46,7 +46,7 @@ import kotlinx.coroutines.flow.transform * @param model the model to use for the interaction * @property history the previous interactions with the model */ -@OptIn(BetaGenAiAPI::class) +@OptIn(GenerativeBeta::class) class Chat(private val model: GenerativeModel, val history: MutableList = ArrayList()) { private var lock = Semaphore(1) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index ea0cefc2..ebcf4857 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -22,7 +22,7 @@ import com.google.ai.client.generativeai.internal.api.CountTokensRequest import com.google.ai.client.generativeai.internal.api.GenerateContentRequest import com.google.ai.client.generativeai.internal.util.toInternal import com.google.ai.client.generativeai.internal.util.toPublic -import com.google.ai.client.generativeai.type.BetaGenAiAPI +import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.Content import com.google.ai.client.generativeai.type.CountTokensResponse import com.google.ai.client.generativeai.type.FinishReason @@ -33,7 +33,6 @@ import com.google.ai.client.generativeai.type.GenerationConfig import com.google.ai.client.generativeai.type.GoogleGenerativeAIException import com.google.ai.client.generativeai.type.NoParameterFunction import com.google.ai.client.generativeai.type.OneParameterFunction -import com.google.ai.client.generativeai.type.ParameterDeclaration import com.google.ai.client.generativeai.type.PromptBlockedException import com.google.ai.client.generativeai.type.ResponseStoppedException import com.google.ai.client.generativeai.type.SafetySetting @@ -55,7 +54,7 @@ import kotlinx.coroutines.flow.map * @property safetySettings the safety bounds to use during alongside prompts during content * generation */ -@OptIn(BetaGenAiAPI::class) +@OptIn(GenerativeBeta::class) class GenerativeModel internal constructor( val modelName: String, @@ -186,7 +185,7 @@ internal constructor( * parameters * @return The output of the requested function call */ - @BetaGenAiAPI + @GenerativeBeta fun executeFunction(functionCallPart: FunctionCallPart): String { if (tools == null) { throw RuntimeException("No registered tools") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index 010ca807..fc730cd8 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -42,7 +42,7 @@ import com.google.ai.client.generativeai.internal.api.shared.HarmCategory import com.google.ai.client.generativeai.internal.api.shared.Part import com.google.ai.client.generativeai.internal.api.shared.SafetySetting import com.google.ai.client.generativeai.internal.api.shared.TextPart -import com.google.ai.client.generativeai.type.BetaGenAiAPI +import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata import com.google.ai.client.generativeai.type.FunctionDeclaration @@ -109,13 +109,13 @@ internal fun BlockThreshold.toInternal() = BlockThreshold.UNSPECIFIED -> HarmBlockThreshold.UNSPECIFIED } -@BetaGenAiAPI +@GenerativeBeta internal fun Tool.toInternal() = com.google.ai.client.generativeai.internal.api.client.Tool( functionDeclarations.map { it.toInternal() } ) -@BetaGenAiAPI +@GenerativeBeta internal fun FunctionDeclaration.toInternal(): com.google.ai.client.generativeai.internal.api.client.FunctionDeclaration { val convertedParams = buildJsonObject { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index ada82d3f..eea24027 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -24,7 +24,7 @@ package com.google.ai.client.generativeai.type * @property description A description of what the function does and its output. * @property function the function implementation */ -@BetaGenAiAPI +@GenerativeBeta class NoParameterFunction( name: String, description: String, @@ -46,7 +46,7 @@ class NoParameterFunction( * @property param A description of the first function parameter * @property function the function implementation */ -@BetaGenAiAPI +@GenerativeBeta class OneParameterFunction( name: String, description: String, @@ -71,7 +71,7 @@ class OneParameterFunction( * @property param2 A description of the second function parameter * @property function the function implementation */ -@BetaGenAiAPI +@GenerativeBeta class TwoParameterFunction( name: String, description: String, @@ -99,7 +99,7 @@ class TwoParameterFunction( * @property param3 A description of the third function parameter * @property function the function implementation */ -@BetaGenAiAPI +@GenerativeBeta class ThreeParameterFunction( name: String, description: String, @@ -130,7 +130,7 @@ class ThreeParameterFunction( * @property param4 A description of the fourth function parameter * @property function the function implementation */ -@BetaGenAiAPI +@GenerativeBeta class FourParameterFunction( name: String, description: String, @@ -151,7 +151,7 @@ class FourParameterFunction( } } -@BetaGenAiAPI +@GenerativeBeta abstract class FunctionDeclaration( val name: String, val description: String, @@ -178,11 +178,11 @@ class ParameterDeclaration( } } -@BetaGenAiAPI +@GenerativeBeta fun defineFunction(name: String, description: String, function: () -> String) = NoParameterFunction(name, description, function) -@BetaGenAiAPI +@GenerativeBeta fun defineFunction( name: String, description: String, @@ -190,7 +190,7 @@ fun defineFunction( function: (T) -> String ) = OneParameterFunction(name, description, arg1, function) -@BetaGenAiAPI +@GenerativeBeta fun defineFunction( name: String, description: String, @@ -199,7 +199,7 @@ fun defineFunction( function: (T, U) -> String ) = TwoParameterFunction(name, description, arg1, arg2, function) -@BetaGenAiAPI +@GenerativeBeta fun defineFunction( name: String, description: String, @@ -209,7 +209,7 @@ fun defineFunction( function: (T, U, W) -> String ) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) -@BetaGenAiAPI +@GenerativeBeta fun defineFunction( name: String, description: String, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt similarity index 96% rename from generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt rename to generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt index 42bc9beb..096c144e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/BetaGenAiAPI.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt @@ -19,4 +19,4 @@ package com.google.ai.client.generativeai.type @RequiresOptIn(message = "This API is only available on the v1beta endpoint") @Retention(AnnotationRetention.BINARY) @Target(AnnotationTarget.CLASS, AnnotationTarget.FUNCTION) -annotation class BetaGenAiAPI +annotation class GenerativeBeta diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt index 404a278b..9e3b2d1e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt @@ -22,7 +22,7 @@ package com.google.ai.client.generativeai.type * * @param functionDeclarations The set of functions that this tool allows the model access to */ -@OptIn(BetaGenAiAPI::class) +@OptIn(GenerativeBeta::class) class Tool( val functionDeclarations: List, ) diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index c3920e9f..2f0150bd 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -16,21 +16,15 @@ package com.google.ai.client.generativeai -import com.google.ai.client.generativeai.type.BetaGenAiAPI -import com.google.ai.client.generativeai.type.FunctionCallPart -import com.google.ai.client.generativeai.type.ParameterDeclaration -import com.google.ai.client.generativeai.type.TwoParameterFunction -import com.google.ai.client.generativeai.type.defineFunction import com.google.ai.client.generativeai.util.commonTest import com.google.ai.client.generativeai.util.createResponses import com.google.ai.client.generativeai.util.prepareStreamingResponse import io.kotest.matchers.shouldBe import io.ktor.utils.io.close import io.ktor.utils.io.writeFully -import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout import org.junit.Test +import kotlin.time.Duration.Companion.seconds internal class GenerativeModelTests { private val testTimeout = 5.seconds @@ -50,40 +44,4 @@ internal class GenerativeModelTests { } } } - - // - // - // FOR DEV PURPOSES ONLY - // - // - - fun myfun(a: Int, b: Int): String { - return (a + b).toString() - } - - @OptIn(BetaGenAiAPI::class) - @Test - fun `calling test`(): Unit = runBlocking { - // val f = - // FunctionBuilder("sum", "add two numbers together") - // .intParam("a", "First number to add together") - // .intParam("b", "Second number to add together") - // .build(::myfun) - - val f = - defineFunction( - "sum", - "add two numbers together", - ParameterDeclaration.int("a", "First number to add together"), - ParameterDeclaration.int("b", "Second number to add together") - ) { a, b -> - (a + b).toString() - } - - val x = f as TwoParameterFunction - val p = FunctionCallPart("sum", mapOf("a" to "2", "b" to "3")) - val q = x(p) - - q shouldBe "5" - } -} +} \ No newline at end of file From 08309a7c09d18eade26799958d72533bfca7ddce Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 22 Feb 2024 10:00:11 -0800 Subject: [PATCH 10/36] re-add suspend functions --- .../ai/client/generativeai/GenerativeModel.kt | 12 +++---- .../generativeai/type/FunctionDeclarations.kt | 33 ++++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index ebcf4857..7b18eabf 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -186,7 +186,7 @@ internal constructor( * @return The output of the requested function call */ @GenerativeBeta - fun executeFunction(functionCallPart: FunctionCallPart): String { + suspend fun executeFunction(functionCallPart: FunctionCallPart): String { if (tools == null) { throw RuntimeException("No registered tools") } @@ -195,14 +195,14 @@ internal constructor( tool.functionDeclarations.firstOrNull() { it.name == functionCallPart.name } ?: throw RuntimeException("No registered function named ${functionCallPart.name}") return when (callable) { - is NoParameterFunction -> callable() - is OneParameterFunction<*> -> (callable as OneParameterFunction)(functionCallPart) + is NoParameterFunction -> callable.execute() + is OneParameterFunction<*> -> (callable as OneParameterFunction).execute(functionCallPart) is TwoParameterFunction<*, *> -> - (callable as TwoParameterFunction)(functionCallPart) + (callable as TwoParameterFunction).execute(functionCallPart) is ThreeParameterFunction<*, *, *> -> - (callable as ThreeParameterFunction)(functionCallPart) + (callable as ThreeParameterFunction).execute(functionCallPart) is FourParameterFunction<*, *, *, *> -> - (callable as FourParameterFunction)(functionCallPart) + (callable as FourParameterFunction).execute(functionCallPart) else -> { throw RuntimeException("UNREACHABLE") } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index eea24027..6b5e073d 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -28,13 +28,13 @@ package com.google.ai.client.generativeai.type class NoParameterFunction( name: String, description: String, - val function: () -> String, + val function: suspend () -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf>() - operator fun invoke() = function() + suspend fun execute() = function() - operator fun invoke(part: FunctionCallPart) = invoke() + override suspend fun execute(part: FunctionCallPart) = function() } /** @@ -51,11 +51,11 @@ class OneParameterFunction( name: String, description: String, val param: ParameterDeclaration, - val function: (T) -> String, + val function: suspend (T) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param) - operator fun invoke(part: FunctionCallPart): String { + override suspend fun execute(part: FunctionCallPart): String { val arg1 = part.getArgOrThrow(param) return function(arg1) } @@ -77,11 +77,11 @@ class TwoParameterFunction( description: String, val param1: ParameterDeclaration, val param2: ParameterDeclaration, - val function: (T, U) -> String, + val function: suspend (T, U) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2) - operator fun invoke(part: FunctionCallPart): String { + override suspend fun execute(part: FunctionCallPart): String { val arg1 = part.getArgOrThrow(param1) val arg2 = part.getArgOrThrow(param2) return function(arg1, arg2) @@ -106,11 +106,11 @@ class ThreeParameterFunction( val param1: ParameterDeclaration, val param2: ParameterDeclaration, val param3: ParameterDeclaration, - val function: (T, U, V) -> String, + val function: suspend (T, U, V) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3) - operator fun invoke(part: FunctionCallPart): String { + override suspend fun execute(part: FunctionCallPart): String { val arg1 = part.getArgOrThrow(param1) val arg2 = part.getArgOrThrow(param2) val arg3 = part.getArgOrThrow(param3) @@ -138,11 +138,11 @@ class FourParameterFunction( val param2: ParameterDeclaration, val param3: ParameterDeclaration, val param4: ParameterDeclaration, - val function: (T, U, V, W) -> String, + val function: suspend (T, U, V, W) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3, param4) - operator fun invoke(part: FunctionCallPart): String { + override suspend fun execute(part: FunctionCallPart): String { val arg1 = part.getArgOrThrow(param1) val arg2 = part.getArgOrThrow(param2) val arg3 = part.getArgOrThrow(param3) @@ -157,6 +157,7 @@ abstract class FunctionDeclaration( val description: String, ) { abstract fun getParameters(): List> + abstract suspend fun execute(part: FunctionCallPart): String } class ParameterDeclaration( @@ -179,7 +180,7 @@ class ParameterDeclaration( } @GenerativeBeta -fun defineFunction(name: String, description: String, function: () -> String) = +fun defineFunction(name: String, description: String, function: suspend () -> String) = NoParameterFunction(name, description, function) @GenerativeBeta @@ -187,7 +188,7 @@ fun defineFunction( name: String, description: String, arg1: ParameterDeclaration, - function: (T) -> String + function: suspend (T) -> String ) = OneParameterFunction(name, description, arg1, function) @GenerativeBeta @@ -196,7 +197,7 @@ fun defineFunction( description: String, arg1: ParameterDeclaration, arg2: ParameterDeclaration, - function: (T, U) -> String + function: suspend (T, U) -> String ) = TwoParameterFunction(name, description, arg1, arg2, function) @GenerativeBeta @@ -206,7 +207,7 @@ fun defineFunction( arg1: ParameterDeclaration, arg2: ParameterDeclaration, arg3: ParameterDeclaration, - function: (T, U, W) -> String + function: suspend (T, U, W) -> String ) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) @GenerativeBeta @@ -217,7 +218,7 @@ fun defineFunction( arg2: ParameterDeclaration, arg3: ParameterDeclaration, arg4: ParameterDeclaration, - function: (T, U, W, Z) -> String + function: suspend (T, U, W, Z) -> String ) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) private fun FunctionCallPart.getArgOrThrow(param: ParameterDeclaration): T { From 0516150808625d383c9a977309815ac7e8a8515f Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 22 Feb 2024 10:12:26 -0800 Subject: [PATCH 11/36] add flag to disable automated function calling --- .../main/java/com/google/ai/client/generativeai/Chat.kt | 7 +++++++ .../ai/client/generativeai/type/GenerationConfig.kt | 8 ++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index 2f586ffa..08697e0e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -73,6 +73,9 @@ class Chat(private val model: GenerativeModel, val history: MutableList tempHistory.add(prompt) tempHistory.add(response.candidates.first().content) if (responsePart is FunctionCallPart) { + if (model.generationConfig?.autoFunction == false) { + break + } val output = model.executeFunction(responsePart) prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output))) } else { @@ -188,6 +191,10 @@ class Chat(private val model: GenerativeModel, val history: MutableList tempHistory.add(Content("model", listOf(part))) } is FunctionCallPart -> { + if (model.generationConfig?.autoFunction == false) { + tempHistory.add(response.candidates.first().content) + continue + } val functionCall = response.candidates.first().content.parts.first { it is FunctionCallPart } as FunctionCallPart diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt index f1ff79d7..4dd3000b 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt @@ -25,6 +25,7 @@ package com.google.ai.client.generativeai.type * @property candidateCount The max *unique* responses to return * @property maxOutputTokens The max tokens to generate per response * @property stopSequences A list of strings to stop generation on occurrence of + * @property autoFunction if false, auto functions will not be automatically executed */ class GenerationConfig private constructor( @@ -33,7 +34,8 @@ private constructor( val topP: Float?, val candidateCount: Int?, val maxOutputTokens: Int?, - val stopSequences: List? + val stopSequences: List?, + val autoFunction: Boolean? ) { class Builder { @@ -43,6 +45,7 @@ private constructor( @JvmField var candidateCount: Int? = null @JvmField var maxOutputTokens: Int? = null @JvmField var stopSequences: List? = null + @JvmField var autoFunction:Boolean? = null fun build() = GenerationConfig( @@ -51,7 +54,8 @@ private constructor( topP = topP, candidateCount = candidateCount, maxOutputTokens = maxOutputTokens, - stopSequences = stopSequences + stopSequences = stopSequences, + autoFunction = autoFunction ) } From 393f918b9611248c3e27cd33ffd5a33d5491acd1 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 22 Feb 2024 10:49:48 -0800 Subject: [PATCH 12/36] move autofunction to request options and other merge conflicts --- .../java/com/google/ai/client/generativeai/Chat.kt | 6 +++--- .../ai/client/generativeai/GenerativeModel.kt | 7 ++++--- .../generativeai/internal/util/conversions.kt | 2 +- .../generativeai/type/FunctionDeclarations.kt | 3 ++- .../ai/client/generativeai/type/GenerationConfig.kt | 4 ---- .../ai/client/generativeai/type/RequestOptions.kt | 13 ++++++++++--- 6 files changed, 20 insertions(+), 15 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index 08697e0e..7ba4d953 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -17,12 +17,12 @@ package com.google.ai.client.generativeai import android.graphics.Bitmap -import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.BlobPart import com.google.ai.client.generativeai.type.Content import com.google.ai.client.generativeai.type.FunctionCallPart import com.google.ai.client.generativeai.type.FunctionResponsePart import com.google.ai.client.generativeai.type.GenerateContentResponse +import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.ImagePart import com.google.ai.client.generativeai.type.InvalidStateException import com.google.ai.client.generativeai.type.TextPart @@ -73,7 +73,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList tempHistory.add(prompt) tempHistory.add(response.candidates.first().content) if (responsePart is FunctionCallPart) { - if (model.generationConfig?.autoFunction == false) { + if (!model.requestOptions.autoFunction) { break } val output = model.executeFunction(responsePart) @@ -191,7 +191,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList tempHistory.add(Content("model", listOf(part))) } is FunctionCallPart -> { - if (model.generationConfig?.autoFunction == false) { + if (!model.requestOptions.autoFunction) { tempHistory.add(response.candidates.first().content) continue } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 0c73f4bf..1f9c2b0e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -22,7 +22,6 @@ import com.google.ai.client.generativeai.internal.api.CountTokensRequest import com.google.ai.client.generativeai.internal.api.GenerateContentRequest import com.google.ai.client.generativeai.internal.util.toInternal import com.google.ai.client.generativeai.internal.util.toPublic -import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.Content import com.google.ai.client.generativeai.type.CountTokensResponse import com.google.ai.client.generativeai.type.FinishReason @@ -30,6 +29,7 @@ import com.google.ai.client.generativeai.type.FourParameterFunction import com.google.ai.client.generativeai.type.FunctionCallPart import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.GenerationConfig +import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.GoogleGenerativeAIException import com.google.ai.client.generativeai.type.NoParameterFunction import com.google.ai.client.generativeai.type.OneParameterFunction @@ -74,7 +74,7 @@ internal constructor( apiKey: String, generationConfig: GenerationConfig? = null, safetySettings: List? = null, - tools: List? = null + tools: List? = null, requestOptions: RequestOptions = RequestOptions(), ) : this( modelName, @@ -201,7 +201,8 @@ internal constructor( ?: throw RuntimeException("No registered function named ${functionCallPart.name}") return when (callable) { is NoParameterFunction -> callable.execute() - is OneParameterFunction<*> -> (callable as OneParameterFunction).execute(functionCallPart) + is OneParameterFunction<*> -> + (callable as OneParameterFunction).execute(functionCallPart) is TwoParameterFunction<*, *> -> (callable as TwoParameterFunction).execute(functionCallPart) is ThreeParameterFunction<*, *, *> -> diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index fc730cd8..cc19a0cb 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -42,10 +42,10 @@ import com.google.ai.client.generativeai.internal.api.shared.HarmCategory import com.google.ai.client.generativeai.internal.api.shared.Part import com.google.ai.client.generativeai.internal.api.shared.SafetySetting import com.google.ai.client.generativeai.internal.api.shared.TextPart -import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata import com.google.ai.client.generativeai.type.FunctionDeclaration +import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.ImagePart import com.google.ai.client.generativeai.type.SerializationException import com.google.ai.client.generativeai.type.Tool diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 6b5e073d..5396483b 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -142,7 +142,7 @@ class FourParameterFunction( ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3, param4) - override suspend fun execute(part: FunctionCallPart): String { + override suspend fun execute(part: FunctionCallPart): String { val arg1 = part.getArgOrThrow(param1) val arg2 = part.getArgOrThrow(param2) val arg3 = part.getArgOrThrow(param3) @@ -157,6 +157,7 @@ abstract class FunctionDeclaration( val description: String, ) { abstract fun getParameters(): List> + abstract suspend fun execute(part: FunctionCallPart): String } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt index 4dd3000b..972b51e0 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerationConfig.kt @@ -25,7 +25,6 @@ package com.google.ai.client.generativeai.type * @property candidateCount The max *unique* responses to return * @property maxOutputTokens The max tokens to generate per response * @property stopSequences A list of strings to stop generation on occurrence of - * @property autoFunction if false, auto functions will not be automatically executed */ class GenerationConfig private constructor( @@ -35,7 +34,6 @@ private constructor( val candidateCount: Int?, val maxOutputTokens: Int?, val stopSequences: List?, - val autoFunction: Boolean? ) { class Builder { @@ -45,7 +43,6 @@ private constructor( @JvmField var candidateCount: Int? = null @JvmField var maxOutputTokens: Int? = null @JvmField var stopSequences: List? = null - @JvmField var autoFunction:Boolean? = null fun build() = GenerationConfig( @@ -55,7 +52,6 @@ private constructor( candidateCount = candidateCount, maxOutputTokens = maxOutputTokens, stopSequences = stopSequences, - autoFunction = autoFunction ) } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt index cc9669d9..e1074dc3 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt @@ -27,14 +27,21 @@ import kotlin.time.toDuration * @property timeout the maximum amount of time for a request to take, from the first request to * first response. * @property apiVersion the api endpoint to call. + * @property autoFunction if false, auto functions will not be automatically executed */ -class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") { +class RequestOptions( + val timeout: Duration, + val apiVersion: String = "v1", + val autoFunction: Boolean = true +) { @JvmOverloads constructor( timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS, - apiVersion: String = "v1" + apiVersion: String = "v1", + autoFunction: Boolean = true ) : this( (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS), - apiVersion + apiVersion, + autoFunction ) } From 931cf003487942c734e5d617b1b0a9d74dc46ec8 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Thu, 22 Feb 2024 10:51:43 -0800 Subject: [PATCH 13/36] add license to Type.kt --- .../google/ai/client/generativeai/type/Type.kt | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt index 53f0dce2..6e94728f 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt @@ -1,3 +1,19 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.google.ai.client.generativeai.type class FunctionType(val name: String, val parse: (String?) -> T?) { From 3ff810598ab8598358366ab031f818c85ac399a6 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Fri, 23 Feb 2024 10:28:11 -0800 Subject: [PATCH 14/36] swap autoFunctionFlag to the negative --- .../src/main/java/com/google/ai/client/generativeai/Chat.kt | 4 ++-- .../com/google/ai/client/generativeai/type/RequestOptions.kt | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index 7ba4d953..7131c7de 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -73,7 +73,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList tempHistory.add(prompt) tempHistory.add(response.candidates.first().content) if (responsePart is FunctionCallPart) { - if (!model.requestOptions.autoFunction) { + if (!model.requestOptions.disableAutoFunction) { break } val output = model.executeFunction(responsePart) @@ -191,7 +191,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList tempHistory.add(Content("model", listOf(part))) } is FunctionCallPart -> { - if (!model.requestOptions.autoFunction) { + if (!model.requestOptions.disableAutoFunction) { tempHistory.add(response.candidates.first().content) continue } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt index e1074dc3..88164ec4 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt @@ -27,12 +27,12 @@ import kotlin.time.toDuration * @property timeout the maximum amount of time for a request to take, from the first request to * first response. * @property apiVersion the api endpoint to call. - * @property autoFunction if false, auto functions will not be automatically executed + * @property disableAutoFunction if true, auto functions will not be automatically executed */ class RequestOptions( val timeout: Duration, val apiVersion: String = "v1", - val autoFunction: Boolean = true + val disableAutoFunction: Boolean = false ) { @JvmOverloads constructor( From e9c3adc85cc443f63b16e05b812437f0e06acf06 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Fri, 23 Feb 2024 10:29:03 -0800 Subject: [PATCH 15/36] fix dependant if statements --- .../src/main/java/com/google/ai/client/generativeai/Chat.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index 7131c7de..48f2058e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -73,7 +73,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList tempHistory.add(prompt) tempHistory.add(response.candidates.first().content) if (responsePart is FunctionCallPart) { - if (!model.requestOptions.disableAutoFunction) { + if (model.requestOptions.disableAutoFunction) { break } val output = model.executeFunction(responsePart) @@ -191,7 +191,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList tempHistory.add(Content("model", listOf(part))) } is FunctionCallPart -> { - if (!model.requestOptions.disableAutoFunction) { + if (model.requestOptions.disableAutoFunction) { tempHistory.add(response.candidates.first().content) continue } From 6c4eb818219afd5868d686121d32bfac39cda441 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Tue, 5 Mar 2024 12:51:48 -0800 Subject: [PATCH 16/36] swap output to JsonObject --- .../ai/client/generativeai/GenerativeModel.kt | 3 +- .../generativeai/internal/api/shared/Types.kt | 6 ++-- .../generativeai/internal/util/conversions.kt | 5 ++- .../generativeai/type/FunctionDeclarations.kt | 32 ++++++++++--------- .../ai/client/generativeai/type/Part.kt | 3 +- 5 files changed, 26 insertions(+), 23 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 1f9c2b0e..8eaf24da 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -45,6 +45,7 @@ import com.google.ai.client.generativeai.type.content import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.map +import kotlinx.serialization.json.JsonObject /** * A facilitator for a given multimodal model (eg; Gemini). @@ -191,7 +192,7 @@ internal constructor( * @return The output of the requested function call */ @GenerativeBeta - suspend fun executeFunction(functionCallPart: FunctionCallPart): String { + suspend fun executeFunction(functionCallPart: FunctionCallPart): JsonObject { if (tools == null) { throw RuntimeException("No registered tools") } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt index 1993d8d1..2dd191d0 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt @@ -24,7 +24,9 @@ import kotlinx.serialization.Serializable import kotlinx.serialization.SerializationException import kotlinx.serialization.json.JsonContentPolymorphicSerializer import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.jsonObject +import org.json.JSONObject internal object HarmCategorySerializer : KSerializer by FirstOrdinalSerializer(HarmCategory::class) @@ -60,12 +62,10 @@ internal data class Blob( ) @Serializable -internal data class FunctionResponse(val name: String, val response: FunctionResponseData) +internal data class FunctionResponse(val name: String, val response: JsonObject) @Serializable internal data class FunctionCall(val name: String, val args: Map) -@Serializable internal data class FunctionResponseData(val name: String, val content: String) - @Serializable internal data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index cc19a0cb..503c90b5 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -35,7 +35,6 @@ import com.google.ai.client.generativeai.internal.api.shared.Content import com.google.ai.client.generativeai.internal.api.shared.FunctionCall import com.google.ai.client.generativeai.internal.api.shared.FunctionCallPart import com.google.ai.client.generativeai.internal.api.shared.FunctionResponse -import com.google.ai.client.generativeai.internal.api.shared.FunctionResponseData import com.google.ai.client.generativeai.internal.api.shared.FunctionResponsePart import com.google.ai.client.generativeai.internal.api.shared.HarmBlockThreshold import com.google.ai.client.generativeai.internal.api.shared.HarmCategory @@ -68,7 +67,7 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part { is com.google.ai.client.generativeai.type.FunctionCallPart -> FunctionCallPart(FunctionCall(name, args)) is com.google.ai.client.generativeai.type.FunctionResponsePart -> - FunctionResponsePart(FunctionResponse(name, FunctionResponseData(name, response))) + FunctionResponsePart(FunctionResponse(name, response)) else -> throw SerializationException( "The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet." @@ -175,7 +174,7 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part { is FunctionResponsePart -> com.google.ai.client.generativeai.type.FunctionResponsePart( functionResponse.name, - functionResponse.response.content + functionResponse.response ) } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 5396483b..5bce3c7c 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -16,6 +16,8 @@ package com.google.ai.client.generativeai.type +import kotlinx.serialization.json.JsonObject + /** * A declared function, including implementation, that a model can be given access to in order to * gain info or complete tasks. @@ -28,7 +30,7 @@ package com.google.ai.client.generativeai.type class NoParameterFunction( name: String, description: String, - val function: suspend () -> String, + val function: suspend () -> JsonObject, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf>() @@ -51,11 +53,11 @@ class OneParameterFunction( name: String, description: String, val param: ParameterDeclaration, - val function: suspend (T) -> String, + val function: suspend (T) -> JsonObject, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param) - override suspend fun execute(part: FunctionCallPart): String { + override suspend fun execute(part: FunctionCallPart): JsonObject { val arg1 = part.getArgOrThrow(param) return function(arg1) } @@ -77,11 +79,11 @@ class TwoParameterFunction( description: String, val param1: ParameterDeclaration, val param2: ParameterDeclaration, - val function: suspend (T, U) -> String, + val function: suspend (T, U) -> JsonObject, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2) - override suspend fun execute(part: FunctionCallPart): String { + override suspend fun execute(part: FunctionCallPart): JsonObject { val arg1 = part.getArgOrThrow(param1) val arg2 = part.getArgOrThrow(param2) return function(arg1, arg2) @@ -106,11 +108,11 @@ class ThreeParameterFunction( val param1: ParameterDeclaration, val param2: ParameterDeclaration, val param3: ParameterDeclaration, - val function: suspend (T, U, V) -> String, + val function: suspend (T, U, V) -> JsonObject, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3) - override suspend fun execute(part: FunctionCallPart): String { + override suspend fun execute(part: FunctionCallPart): JsonObject { val arg1 = part.getArgOrThrow(param1) val arg2 = part.getArgOrThrow(param2) val arg3 = part.getArgOrThrow(param3) @@ -138,11 +140,11 @@ class FourParameterFunction( val param2: ParameterDeclaration, val param3: ParameterDeclaration, val param4: ParameterDeclaration, - val function: suspend (T, U, V, W) -> String, + val function: suspend (T, U, V, W) -> JsonObject, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3, param4) - override suspend fun execute(part: FunctionCallPart): String { + override suspend fun execute(part: FunctionCallPart): JsonObject { val arg1 = part.getArgOrThrow(param1) val arg2 = part.getArgOrThrow(param2) val arg3 = part.getArgOrThrow(param3) @@ -158,7 +160,7 @@ abstract class FunctionDeclaration( ) { abstract fun getParameters(): List> - abstract suspend fun execute(part: FunctionCallPart): String + abstract suspend fun execute(part: FunctionCallPart): JsonObject } class ParameterDeclaration( @@ -181,7 +183,7 @@ class ParameterDeclaration( } @GenerativeBeta -fun defineFunction(name: String, description: String, function: suspend () -> String) = +fun defineFunction(name: String, description: String, function: suspend () -> JsonObject) = NoParameterFunction(name, description, function) @GenerativeBeta @@ -189,7 +191,7 @@ fun defineFunction( name: String, description: String, arg1: ParameterDeclaration, - function: suspend (T) -> String + function: suspend (T) -> JsonObject ) = OneParameterFunction(name, description, arg1, function) @GenerativeBeta @@ -198,7 +200,7 @@ fun defineFunction( description: String, arg1: ParameterDeclaration, arg2: ParameterDeclaration, - function: suspend (T, U) -> String + function: suspend (T, U) -> JsonObject ) = TwoParameterFunction(name, description, arg1, arg2, function) @GenerativeBeta @@ -208,7 +210,7 @@ fun defineFunction( arg1: ParameterDeclaration, arg2: ParameterDeclaration, arg3: ParameterDeclaration, - function: suspend (T, U, W) -> String + function: suspend (T, U, W) -> JsonObject ) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) @GenerativeBeta @@ -219,7 +221,7 @@ fun defineFunction( arg2: ParameterDeclaration, arg3: ParameterDeclaration, arg4: ParameterDeclaration, - function: suspend (T, U, W, Z) -> String + function: suspend (T, U, W, Z) -> JsonObject ) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) private fun FunctionCallPart.getArgOrThrow(param: ParameterDeclaration): T { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt index ebdde06b..65e5714e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt @@ -17,6 +17,7 @@ package com.google.ai.client.generativeai.type import android.graphics.Bitmap +import kotlinx.serialization.json.JsonObject /** * Interface representing data sent to and received from requests. @@ -44,7 +45,7 @@ class BlobPart(val mimeType: String, val blob: ByteArray) : Part class FunctionCallPart(val name: String, val args: Map) : Part /** Represents function call output to be returned to the model when it requests a function call */ -class FunctionResponsePart(val name: String, val response: String) : Part +class FunctionResponsePart(val name: String, val response: JsonObject) : Part /** @return The part as a [String] if it represents text, and null otherwise */ fun Part.asTextOrNull(): String? = (this as? TextPart)?.text From 6a50c767b2c1c14e9345ed703e3d78f072bc1f6e Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Tue, 5 Mar 2024 13:14:32 -0800 Subject: [PATCH 17/36] swap public implementation to org.json --- .../ai/client/generativeai/GenerativeModel.kt | 4 +-- .../generativeai/internal/api/shared/Types.kt | 4 +-- .../generativeai/internal/util/conversions.kt | 12 +++++-- .../generativeai/type/FunctionDeclarations.kt | 32 +++++++++---------- .../ai/client/generativeai/type/Part.kt | 4 +-- 5 files changed, 31 insertions(+), 25 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 8eaf24da..05c6d49e 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -45,7 +45,7 @@ import com.google.ai.client.generativeai.type.content import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.map -import kotlinx.serialization.json.JsonObject +import org.json.JSONObject /** * A facilitator for a given multimodal model (eg; Gemini). @@ -192,7 +192,7 @@ internal constructor( * @return The output of the requested function call */ @GenerativeBeta - suspend fun executeFunction(functionCallPart: FunctionCallPart): JsonObject { + suspend fun executeFunction(functionCallPart: FunctionCallPart): JSONObject { if (tools == null) { throw RuntimeException("No registered tools") } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt index 2dd191d0..80eecac3 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/shared/Types.kt @@ -26,7 +26,6 @@ import kotlinx.serialization.json.JsonContentPolymorphicSerializer import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.jsonObject -import org.json.JSONObject internal object HarmCategorySerializer : KSerializer by FirstOrdinalSerializer(HarmCategory::class) @@ -61,8 +60,7 @@ internal data class Blob( val data: Base64, ) -@Serializable -internal data class FunctionResponse(val name: String, val response: JsonObject) +@Serializable internal data class FunctionResponse(val name: String, val response: JsonObject) @Serializable internal data class FunctionCall(val name: String, val args: Map) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index 503c90b5..9c5640a1 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -50,8 +50,12 @@ import com.google.ai.client.generativeai.type.SerializationException import com.google.ai.client.generativeai.type.Tool import com.google.ai.client.generativeai.type.content import java.io.ByteArrayOutputStream +import kotlinx.serialization.decodeFromString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.buildJsonObject +import org.json.JSONObject private const val BASE_64_FLAGS = Base64.NO_WRAP @@ -67,7 +71,7 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part { is com.google.ai.client.generativeai.type.FunctionCallPart -> FunctionCallPart(FunctionCall(name, args)) is com.google.ai.client.generativeai.type.FunctionResponsePart -> - FunctionResponsePart(FunctionResponse(name, response)) + FunctionResponsePart(FunctionResponse(name, response.toInternal())) else -> throw SerializationException( "The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet." @@ -139,6 +143,8 @@ internal fun FunctionDeclaration.toInternal(): ) } +internal fun JSONObject.toInternal() = Json.decodeFromString(toString()) + internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candidate { val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty() @@ -174,7 +180,7 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part { is FunctionResponsePart -> com.google.ai.client.generativeai.type.FunctionResponsePart( functionResponse.name, - functionResponse.response + functionResponse.response.toPublic() ) } } @@ -244,6 +250,8 @@ internal fun GenerateContentResponse.toPublic() = internal fun CountTokensResponse.toPublic() = com.google.ai.client.generativeai.type.CountTokensResponse(totalTokens) +internal fun JsonObject.toPublic() = JSONObject(toString()) + private fun encodeBitmapToBase64Png(input: Bitmap): String { ByteArrayOutputStream().let { input.compress(Bitmap.CompressFormat.JPEG, 80, it) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 5bce3c7c..b1892c47 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -16,7 +16,7 @@ package com.google.ai.client.generativeai.type -import kotlinx.serialization.json.JsonObject +import org.json.JSONObject /** * A declared function, including implementation, that a model can be given access to in order to @@ -30,7 +30,7 @@ import kotlinx.serialization.json.JsonObject class NoParameterFunction( name: String, description: String, - val function: suspend () -> JsonObject, + val function: suspend () -> JSONObject, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf>() @@ -53,11 +53,11 @@ class OneParameterFunction( name: String, description: String, val param: ParameterDeclaration, - val function: suspend (T) -> JsonObject, + val function: suspend (T) -> JSONObject, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param) - override suspend fun execute(part: FunctionCallPart): JsonObject { + override suspend fun execute(part: FunctionCallPart): JSONObject { val arg1 = part.getArgOrThrow(param) return function(arg1) } @@ -79,11 +79,11 @@ class TwoParameterFunction( description: String, val param1: ParameterDeclaration, val param2: ParameterDeclaration, - val function: suspend (T, U) -> JsonObject, + val function: suspend (T, U) -> JSONObject, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2) - override suspend fun execute(part: FunctionCallPart): JsonObject { + override suspend fun execute(part: FunctionCallPart): JSONObject { val arg1 = part.getArgOrThrow(param1) val arg2 = part.getArgOrThrow(param2) return function(arg1, arg2) @@ -108,11 +108,11 @@ class ThreeParameterFunction( val param1: ParameterDeclaration, val param2: ParameterDeclaration, val param3: ParameterDeclaration, - val function: suspend (T, U, V) -> JsonObject, + val function: suspend (T, U, V) -> JSONObject, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3) - override suspend fun execute(part: FunctionCallPart): JsonObject { + override suspend fun execute(part: FunctionCallPart): JSONObject { val arg1 = part.getArgOrThrow(param1) val arg2 = part.getArgOrThrow(param2) val arg3 = part.getArgOrThrow(param3) @@ -140,11 +140,11 @@ class FourParameterFunction( val param2: ParameterDeclaration, val param3: ParameterDeclaration, val param4: ParameterDeclaration, - val function: suspend (T, U, V, W) -> JsonObject, + val function: suspend (T, U, V, W) -> JSONObject, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3, param4) - override suspend fun execute(part: FunctionCallPart): JsonObject { + override suspend fun execute(part: FunctionCallPart): JSONObject { val arg1 = part.getArgOrThrow(param1) val arg2 = part.getArgOrThrow(param2) val arg3 = part.getArgOrThrow(param3) @@ -160,7 +160,7 @@ abstract class FunctionDeclaration( ) { abstract fun getParameters(): List> - abstract suspend fun execute(part: FunctionCallPart): JsonObject + abstract suspend fun execute(part: FunctionCallPart): JSONObject } class ParameterDeclaration( @@ -183,7 +183,7 @@ class ParameterDeclaration( } @GenerativeBeta -fun defineFunction(name: String, description: String, function: suspend () -> JsonObject) = +fun defineFunction(name: String, description: String, function: suspend () -> JSONObject) = NoParameterFunction(name, description, function) @GenerativeBeta @@ -191,7 +191,7 @@ fun defineFunction( name: String, description: String, arg1: ParameterDeclaration, - function: suspend (T) -> JsonObject + function: suspend (T) -> JSONObject ) = OneParameterFunction(name, description, arg1, function) @GenerativeBeta @@ -200,7 +200,7 @@ fun defineFunction( description: String, arg1: ParameterDeclaration, arg2: ParameterDeclaration, - function: suspend (T, U) -> JsonObject + function: suspend (T, U) -> JSONObject ) = TwoParameterFunction(name, description, arg1, arg2, function) @GenerativeBeta @@ -210,7 +210,7 @@ fun defineFunction( arg1: ParameterDeclaration, arg2: ParameterDeclaration, arg3: ParameterDeclaration, - function: suspend (T, U, W) -> JsonObject + function: suspend (T, U, W) -> JSONObject ) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) @GenerativeBeta @@ -221,7 +221,7 @@ fun defineFunction( arg2: ParameterDeclaration, arg3: ParameterDeclaration, arg4: ParameterDeclaration, - function: suspend (T, U, W, Z) -> JsonObject + function: suspend (T, U, W, Z) -> JSONObject ) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) private fun FunctionCallPart.getArgOrThrow(param: ParameterDeclaration): T { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt index 65e5714e..83f22953 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt @@ -17,7 +17,7 @@ package com.google.ai.client.generativeai.type import android.graphics.Bitmap -import kotlinx.serialization.json.JsonObject +import org.json.JSONObject /** * Interface representing data sent to and received from requests. @@ -45,7 +45,7 @@ class BlobPart(val mimeType: String, val blob: ByteArray) : Part class FunctionCallPart(val name: String, val args: Map) : Part /** Represents function call output to be returned to the model when it requests a function call */ -class FunctionResponsePart(val name: String, val response: JsonObject) : Part +class FunctionResponsePart(val name: String, val response: JSONObject) : Part /** @return The part as a [String] if it represents text, and null otherwise */ fun Part.asTextOrNull(): String? = (this as? TextPart)?.text From 6ad88696179c1fa546e4204138327a766114e153 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Tue, 5 Mar 2024 13:21:10 -0800 Subject: [PATCH 18/36] fix request options default --- .../com/google/ai/client/generativeai/type/RequestOptions.kt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt index 88164ec4..5b5d83d0 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt @@ -38,10 +38,10 @@ class RequestOptions( constructor( timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS, apiVersion: String = "v1", - autoFunction: Boolean = true + disableAutoFunction: Boolean = false ) : this( (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS), apiVersion, - autoFunction + disableAutoFunction ) } From 266b335ef5c0823e3ea553dff350dda6abbb3df3 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Mon, 11 Mar 2024 15:23:11 -0700 Subject: [PATCH 19/36] Update generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt Co-authored-by: Daymon <17409137+daymxn@users.noreply.github.com> --- .../java/com/google/ai/client/generativeai/Chat.kt | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index e50d0403..5247d447 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -72,15 +72,11 @@ class Chat(private val model: GenerativeModel, val history: MutableList tempHistory.add(prompt) tempHistory.add(response.candidates.first().content) - if (responsePart is FunctionCallPart) { - if (model.requestOptions.disableAutoFunction) { - break - } - val output = model.executeFunction(responsePart) - prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output))) - } else { - break - } + if (responsePart !is FunctionCallPart) break + if (model.requestOptions.disableAutoFunction) break + + val output = model.executeFunction(responsePart) + prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output))) } history.addAll(tempHistory) return response From 6684478fd99bcd81580694e4a7043b259790387b Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Mon, 11 Mar 2024 15:23:19 -0700 Subject: [PATCH 20/36] Update generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt Co-authored-by: Daymon <17409137+daymxn@users.noreply.github.com> --- .../src/main/java/com/google/ai/client/generativeai/Chat.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index 5247d447..e8ac37aa 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -120,7 +120,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList attemptLock() val flow = model.generateContentStream(*history.toTypedArray(), prompt) - val tempHistory = LinkedList() + val tempHistory = mutableListOf() tempHistory.add(prompt) /** * TODO: revisit when images and blobs are returned. This will cause issues with how things are From 7c57e129e86f751fbe2be1daa3a34cf23166674d Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Tue, 12 Mar 2024 14:51:48 -0700 Subject: [PATCH 21/36] Add extra datatypes and fix serialization --- .../com/google/ai/client/generativeai/Chat.kt | 4 +- .../generativeai/internal/util/conversions.kt | 29 ++++++++------ .../generativeai/type/FunctionDeclarations.kt | 39 ++++++++++++------- .../ai/client/generativeai/type/Type.kt | 12 +++++- 4 files changed, 55 insertions(+), 29 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index e8ac37aa..cf3654ac 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -169,7 +169,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList private suspend fun automaticFunctionExecutingTransform( transformer: FlowCollector, - tempHistory: LinkedList, + tempHistory: MutableList, response: GenerateContentResponse ) { for (part in response.candidates.first().content.parts) { @@ -207,7 +207,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList } } - private fun addTextToHistory(tempHistory: LinkedList, textPart: TextPart) { + private fun addTextToHistory(tempHistory: MutableList, textPart: TextPart) { val lastContent = tempHistory.lastOrNull() if (lastContent?.role == "model" && lastContent.parts.any { it is TextPart }) { tempHistory.removeLast() diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index 9c5640a1..1c700b69 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -52,6 +52,7 @@ import com.google.ai.client.generativeai.type.content import java.io.ByteArrayOutputStream import kotlinx.serialization.decodeFromString import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.buildJsonObject @@ -89,7 +90,7 @@ internal fun com.google.ai.client.generativeai.type.GenerationConfig.toInternal( topK = topK, candidateCount = candidateCount, maxOutputTokens = maxOutputTokens, - stopSequences = stopSequences + stopSequences = stopSequences, ) internal fun com.google.ai.client.generativeai.type.HarmCategory.toInternal() = @@ -122,13 +123,17 @@ internal fun Tool.toInternal() = internal fun FunctionDeclaration.toInternal(): com.google.ai.client.generativeai.internal.api.client.FunctionDeclaration { val convertedParams = buildJsonObject { - getParameters().forEach { + getParameters().forEach { pramDeclaration -> put( - it.name, + pramDeclaration.name, buildJsonObject { - put("type", JsonPrimitive("STRING")) - put("description", JsonPrimitive(it.description)) - } + put("type", JsonPrimitive(pramDeclaration.type.name)) + put("description", JsonPrimitive(pramDeclaration.description)) + pramDeclaration.format?.let { format -> put("format", JsonPrimitive(format)) } + pramDeclaration.enum?.let { enum -> + put("enum", JsonArray(enum.map { JsonPrimitive(it) })) + } + }, ) } } @@ -138,8 +143,8 @@ internal fun FunctionDeclaration.toInternal(): com.google.ai.client.generativeai.internal.api.client.FunctionParameters( convertedParams, getParameters().map { it.name }, - "OBJECT" - ) + "OBJECT", + ), ) } @@ -154,7 +159,7 @@ internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candid this.content?.toPublic() ?: content("model") {}, safetyRatings, citations, - finishReason + finishReason, ) } @@ -175,12 +180,12 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part { is FunctionCallPart -> com.google.ai.client.generativeai.type.FunctionCallPart( functionCall.name, - functionCall.args.orEmpty() + functionCall.args.orEmpty(), ) is FunctionResponsePart -> com.google.ai.client.generativeai.type.FunctionResponsePart( functionResponse.name, - functionResponse.response.toPublic() + functionResponse.response.toPublic(), ) } } @@ -244,7 +249,7 @@ internal fun BlockReason.toPublic() = internal fun GenerateContentResponse.toPublic() = com.google.ai.client.generativeai.type.GenerateContentResponse( candidates?.map { it.toPublic() }.orEmpty(), - promptFeedback?.toPublic() + promptFeedback?.toPublic(), ) internal fun CountTokensResponse.toPublic() = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index b1892c47..ea292bb2 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -154,10 +154,7 @@ class FourParameterFunction( } @GenerativeBeta -abstract class FunctionDeclaration( - val name: String, - val description: String, -) { +abstract class FunctionDeclaration(val name: String, val description: String) { abstract fun getParameters(): List> abstract suspend fun execute(part: FunctionCallPart): JSONObject @@ -166,19 +163,33 @@ abstract class FunctionDeclaration( class ParameterDeclaration( val name: String, val description: String, - private val type: FunctionType + val format: String? = null, + val enum: List? = null, + val type: FunctionType, ) { fun fromString(value: String?) = type.parse(value) companion object { fun int(name: String, description: String) = - ParameterDeclaration(name, description, FunctionType.INT) + ParameterDeclaration(name, description, null, null, FunctionType.INTEGER) + + fun str(name: String, description: String) = + ParameterDeclaration(name, description, null, null, FunctionType.STRING) + + fun bool(name: String, description: String) = + ParameterDeclaration(name, description, null, null, FunctionType.BOOLEAN) + + fun num(name: String, description: String) = + ParameterDeclaration(name, description, null, null, FunctionType.NUMBER) + + fun obj(name: String, description: String) = + ParameterDeclaration(name, description, null, null, FunctionType.OBJECT) - fun string(name: String, description: String) = - ParameterDeclaration(name, description, FunctionType.STRING) + fun arr(name: String, description: String) = + ParameterDeclaration>(name, description, null, null, FunctionType.ARRAY) - fun boolean(name: String, description: String) = - ParameterDeclaration(name, description, FunctionType.BOOLEAN) + fun enum(name: String, description: String, values: List) = + ParameterDeclaration(name, description, "enum", values, FunctionType.STRING) } } @@ -191,7 +202,7 @@ fun defineFunction( name: String, description: String, arg1: ParameterDeclaration, - function: suspend (T) -> JSONObject + function: suspend (T) -> JSONObject, ) = OneParameterFunction(name, description, arg1, function) @GenerativeBeta @@ -200,7 +211,7 @@ fun defineFunction( description: String, arg1: ParameterDeclaration, arg2: ParameterDeclaration, - function: suspend (T, U) -> JSONObject + function: suspend (T, U) -> JSONObject, ) = TwoParameterFunction(name, description, arg1, arg2, function) @GenerativeBeta @@ -210,7 +221,7 @@ fun defineFunction( arg1: ParameterDeclaration, arg2: ParameterDeclaration, arg3: ParameterDeclaration, - function: suspend (T, U, W) -> JSONObject + function: suspend (T, U, W) -> JSONObject, ) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) @GenerativeBeta @@ -221,7 +232,7 @@ fun defineFunction( arg2: ParameterDeclaration, arg3: ParameterDeclaration, arg4: ParameterDeclaration, - function: suspend (T, U, W, Z) -> JSONObject + function: suspend (T, U, W, Z) -> JSONObject, ) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) private fun FunctionCallPart.getArgOrThrow(param: ParameterDeclaration): T { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt index 6e94728f..c8bb9347 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt @@ -16,10 +16,20 @@ package com.google.ai.client.generativeai.type +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.jsonArray +import org.json.JSONObject + class FunctionType(val name: String, val parse: (String?) -> T?) { companion object { val STRING = FunctionType("STRING") { it } - val INT = FunctionType("INTEGER") { it?.toIntOrNull() } + val INTEGER = FunctionType("INTEGER") { it?.toLongOrNull() } + val NUMBER = FunctionType("NUMBER") { it?.toDoubleOrNull() } val BOOLEAN = FunctionType("BOOLEAN") { it?.toBoolean() } + val ARRAY = + FunctionType>("ARRAY") { it -> + it?.let { Json.parseToJsonElement(it).jsonArray.map { element -> element.toString() } } + } + val OBJECT = FunctionType("OBJECT") { it?.let { JSONObject(it) } } } } From 571308fbd1b6f958012e336c2603445d001be012 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Tue, 12 Mar 2024 14:53:18 -0700 Subject: [PATCH 22/36] update copyright --- .github/hooks/pre-push | 2 +- .github/hooks/pre-push.d/formatting.sh | 2 +- .github/hooks/pre-push.d/licensing.sh | 2 +- INSTALL_HOOKS.sh | 2 +- build.gradle.kts | 2 +- change | 2 +- generativeai-android-sample/.google/packaging.yaml | 2 +- generativeai-android-sample/app/build.gradle.kts | 2 +- generativeai-android-sample/app/src/main/AndroidManifest.xml | 2 +- .../kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt | 2 +- .../app/src/main/kotlin/com/google/ai/sample/MainActivity.kt | 2 +- .../app/src/main/kotlin/com/google/ai/sample/MenuScreen.kt | 2 +- .../kotlin/com/google/ai/sample/feature/chat/ChatMessage.kt | 2 +- .../main/kotlin/com/google/ai/sample/feature/chat/ChatScreen.kt | 2 +- .../kotlin/com/google/ai/sample/feature/chat/ChatUiState.kt | 2 +- .../kotlin/com/google/ai/sample/feature/chat/ChatViewModel.kt | 2 +- .../google/ai/sample/feature/multimodal/PhotoReasoningScreen.kt | 2 +- .../ai/sample/feature/multimodal/PhotoReasoningUiState.kt | 2 +- .../ai/sample/feature/multimodal/PhotoReasoningViewModel.kt | 2 +- .../kotlin/com/google/ai/sample/feature/text/SummarizeScreen.kt | 2 +- .../com/google/ai/sample/feature/text/SummarizeUiState.kt | 2 +- .../com/google/ai/sample/feature/text/SummarizeViewModel.kt | 2 +- .../app/src/main/kotlin/com/google/ai/sample/ui/theme/Color.kt | 2 +- .../app/src/main/kotlin/com/google/ai/sample/ui/theme/Theme.kt | 2 +- .../app/src/main/kotlin/com/google/ai/sample/ui/theme/Type.kt | 2 +- .../app/src/main/kotlin/com/google/ai/sample/util/UriSaver.kt | 2 +- .../app/src/main/res/drawable/ic_launcher_background.xml | 2 +- .../app/src/main/res/drawable/ic_launcher_foreground.xml | 2 +- .../app/src/main/res/mipmap-anydpi/ic_launcher.xml | 2 +- .../app/src/main/res/mipmap-anydpi/ic_launcher_round.xml | 2 +- generativeai-android-sample/app/src/main/res/values/colors.xml | 2 +- generativeai-android-sample/app/src/main/res/values/strings.xml | 2 +- generativeai-android-sample/app/src/main/res/values/themes.xml | 2 +- .../app/src/main/res/xml/backup_rules.xml | 2 +- .../app/src/main/res/xml/data_extraction_rules.xml | 2 +- generativeai-android-sample/build.gradle.kts | 2 +- generativeai-android-sample/settings.gradle.kts | 2 +- generativeai/build.gradle.kts | 2 +- .../google/ai/client/generativeai/ExampleInstrumentedTest.kt | 2 +- generativeai/src/main/AndroidManifest.xml | 2 +- .../src/main/java/com/google/ai/client/generativeai/Chat.kt | 2 +- .../java/com/google/ai/client/generativeai/GenerativeModel.kt | 2 +- .../google/ai/client/generativeai/internal/api/APIController.kt | 2 +- .../com/google/ai/client/generativeai/internal/api/Request.kt | 2 +- .../com/google/ai/client/generativeai/internal/api/Response.kt | 2 +- .../google/ai/client/generativeai/internal/api/client/Types.kt | 2 +- .../google/ai/client/generativeai/internal/api/server/Types.kt | 2 +- .../google/ai/client/generativeai/internal/api/shared/Types.kt | 2 +- .../google/ai/client/generativeai/internal/util/conversions.kt | 2 +- .../com/google/ai/client/generativeai/internal/util/kotlin.kt | 2 +- .../com/google/ai/client/generativeai/internal/util/ktor.kt | 2 +- .../ai/client/generativeai/internal/util/serialization.kt | 2 +- .../java/com/google/ai/client/generativeai/java/ChatFutures.kt | 2 +- .../ai/client/generativeai/java/GenerativeModelFutures.kt | 2 +- .../com/google/ai/client/generativeai/type/BlockThreshold.kt | 2 +- .../java/com/google/ai/client/generativeai/type/Candidate.kt | 2 +- .../main/java/com/google/ai/client/generativeai/type/Content.kt | 2 +- .../google/ai/client/generativeai/type/CountTokensResponse.kt | 2 +- .../java/com/google/ai/client/generativeai/type/Exceptions.kt | 2 +- .../google/ai/client/generativeai/type/FunctionDeclarations.kt | 2 +- .../com/google/ai/client/generativeai/type/FunctionParameter.kt | 2 +- .../ai/client/generativeai/type/GenerateContentResponse.kt | 2 +- .../com/google/ai/client/generativeai/type/GenerationConfig.kt | 2 +- .../com/google/ai/client/generativeai/type/GenerativeBeta.kt | 2 +- .../java/com/google/ai/client/generativeai/type/HarmCategory.kt | 2 +- .../com/google/ai/client/generativeai/type/HarmProbability.kt | 2 +- .../main/java/com/google/ai/client/generativeai/type/Part.kt | 2 +- .../com/google/ai/client/generativeai/type/PromptFeedback.kt | 2 +- .../com/google/ai/client/generativeai/type/SafetySetting.kt | 2 +- .../main/java/com/google/ai/client/generativeai/type/Tool.kt | 2 +- .../main/java/com/google/ai/client/generativeai/type/Type.kt | 2 +- .../com/google/ai/client/generativeai/GenerativeModelTests.kt | 2 +- .../com/google/ai/client/generativeai/StreamingSnapshotTests.kt | 2 +- .../com/google/ai/client/generativeai/UnarySnapshotTests.kt | 2 +- .../test/java/com/google/ai/client/generativeai/util/kotlin.kt | 2 +- .../test/java/com/google/ai/client/generativeai/util/tests.kt | 2 +- plugins/build.gradle.kts | 2 +- plugins/settings.gradle.kts | 2 +- plugins/src/main/java/com/google/gradle/plugins/ApiPlugin.kt | 2 +- .../src/main/java/com/google/gradle/plugins/ChangelogPlugin.kt | 2 +- .../src/main/java/com/google/gradle/plugins/LicensePlugin.kt | 2 +- .../src/main/java/com/google/gradle/tasks/ApplyLicenseTask.kt | 2 +- .../src/main/java/com/google/gradle/tasks/FindChangesTask.kt | 2 +- plugins/src/main/java/com/google/gradle/tasks/MakeChangeTask.kt | 2 +- .../main/java/com/google/gradle/tasks/MakeReleaseNotesTask.kt | 2 +- .../main/java/com/google/gradle/tasks/ValidateLicenseTask.kt | 2 +- .../java/com/google/gradle/tasks/WarnAboutApiChangesTask.kt | 2 +- plugins/src/main/java/com/google/gradle/types/Changelog.kt | 2 +- .../src/main/java/com/google/gradle/types/LicenseTemplate.kt | 2 +- plugins/src/main/java/com/google/gradle/types/LinesChanged.kt | 2 +- plugins/src/main/java/com/google/gradle/types/ModuleVersion.kt | 2 +- .../main/java/com/google/gradle/types/RandomWordsGenerator.kt | 2 +- plugins/src/main/java/com/google/gradle/util/gradle.kt | 2 +- plugins/src/main/java/com/google/gradle/util/kotlin.kt | 2 +- settings.gradle.kts | 2 +- 95 files changed, 95 insertions(+), 95 deletions(-) diff --git a/.github/hooks/pre-push b/.github/hooks/pre-push index 8c91da69..98e51648 100644 --- a/.github/hooks/pre-push +++ b/.github/hooks/pre-push @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/.github/hooks/pre-push.d/formatting.sh b/.github/hooks/pre-push.d/formatting.sh index 318e1f66..62ae6bc6 100755 --- a/.github/hooks/pre-push.d/formatting.sh +++ b/.github/hooks/pre-push.d/formatting.sh @@ -1,5 +1,5 @@ #!/bin/sh -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/.github/hooks/pre-push.d/licensing.sh b/.github/hooks/pre-push.d/licensing.sh index fa28dc9f..7c21e439 100755 --- a/.github/hooks/pre-push.d/licensing.sh +++ b/.github/hooks/pre-push.d/licensing.sh @@ -1,5 +1,5 @@ #!/bin/sh -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/INSTALL_HOOKS.sh b/INSTALL_HOOKS.sh index 3fedaf0d..88de7580 100755 --- a/INSTALL_HOOKS.sh +++ b/INSTALL_HOOKS.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/build.gradle.kts b/build.gradle.kts index 2c30b47e..dd76f23e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/change b/change index eff4bc35..40900da7 100755 --- a/change +++ b/change @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/generativeai-android-sample/.google/packaging.yaml b/generativeai-android-sample/.google/packaging.yaml index c1e03664..9009f593 100644 --- a/generativeai-android-sample/.google/packaging.yaml +++ b/generativeai-android-sample/.google/packaging.yaml @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/generativeai-android-sample/app/build.gradle.kts b/generativeai-android-sample/app/build.gradle.kts index 9e73b16f..78c15b58 100644 --- a/generativeai-android-sample/app/build.gradle.kts +++ b/generativeai-android-sample/app/build.gradle.kts @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/generativeai-android-sample/app/src/main/AndroidManifest.xml b/generativeai-android-sample/app/src/main/AndroidManifest.xml index 0d795b20..3e3e235b 100644 --- a/generativeai-android-sample/app/src/main/AndroidManifest.xml +++ b/generativeai-android-sample/app/src/main/AndroidManifest.xml @@ -1,5 +1,5 @@ -