From 58e46ab7d3041f8205211bad216730346d31f394 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Wed, 20 Mar 2024 18:04:27 -0700 Subject: [PATCH 1/6] split the auto function calling changes between common and generativeai --- .../ai/client/generativeai/common/Request.kt | 3 + .../generativeai/common/RequestOptions.kt | 10 +- .../generativeai/common/client/Types.kt | 20 ++ .../generativeai/common/shared/Types.kt | 12 + generativeai/build.gradle.kts | 1 + .../com/google/ai/client/generativeai/Chat.kt | 133 +++++--- .../ai/client/generativeai/GenerativeModel.kt | 52 +++- .../generativeai/internal/util/conversions.kt | 58 +++- .../generativeai/type/FunctionDeclarations.kt | 291 ++++++++++++++++++ .../generativeai/type/FunctionParameter.kt | 19 ++ .../type/GenerateContentResponse.kt | 6 + .../generativeai/type/GenerativeBeta.kt | 22 ++ .../ai/client/generativeai/type/Part.kt | 7 + .../generativeai/type/RequestOptions.kt | 16 +- .../ai/client/generativeai/type/Tool.kt | 28 ++ .../ai/client/generativeai/type/Type.kt | 42 +++ 16 files changed, 674 insertions(+), 46 deletions(-) create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index 89afce30..bdb9ec13 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -17,8 +17,10 @@ package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.client.GenerationConfig +import com.google.ai.client.generativeai.common.client.Tool import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.SafetySetting + import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -30,6 +32,7 @@ data class GenerateContentRequest( val contents: List, @SerialName("safety_settings") val safetySettings: List? = null, @SerialName("generation_config") val generationConfig: GenerationConfig? = null, + val tools: List? = null ) : Request @Serializable diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt index 4decb5c1..3ca7455d 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/RequestOptions.kt @@ -27,20 +27,24 @@ import kotlin.time.toDuration * @property timeout the maximum amount of time for a request to take, from the first request to * first response. * @property apiVersion the api endpoint to call. + * @property disableAutoFunction if true, auto functions will not be automatically executed */ class RequestOptions( val timeout: Duration, val apiVersion: String = "v1", - val endpoint: String = "https://generativelanguage.googleapis.com" + val disableAutoFunction: Boolean = false, + val endpoint: String = "https://generativelanguage.googleapis.com", ) { @JvmOverloads constructor( timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS, apiVersion: String = "v1", - endpoint: String = "https://generativelanguage.googleapis.com" + disableAutoFunction: Boolean = false, + endpoint: String = "https://generativelanguage.googleapis.com", ) : this( (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS), apiVersion, - endpoint + disableAutoFunction, + endpoint, ) } diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index bacbb1f1..41507836 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -28,3 +28,23 @@ data class GenerationConfig( @SerialName("max_output_tokens") val maxOutputTokens: Int?, @SerialName("stop_sequences") val stopSequences: List? ) + +@Serializable data class Tool(val functionDeclarations: List) + +@Serializable +data class FunctionDeclaration( + val name: String, + val description: String, + val parameters: FunctionParameterProperties, +) + +@Serializable +data class FunctionParameterProperties( + val type: String, + val description: String? = null, + val format: String? = null, + val enum: List? = null, + val properties: Map? = null, + val required: List? = null, + val items: FunctionParameterProperties? = null +) \ No newline at end of file diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt index 8fc641be..7fd246c7 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -26,6 +26,7 @@ import kotlinx.serialization.Serializable import kotlinx.serialization.SerializationException import kotlinx.serialization.json.JsonContentPolymorphicSerializer import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.jsonObject object HarmCategorySerializer : @@ -52,12 +53,21 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List) + @Serializable data class SafetySetting(val category: HarmCategory, val threshold: HarmBlockThreshold) @@ -76,6 +86,8 @@ object PartSerializer : JsonContentPolymorphicSerializer(Part::class) { return when { "text" in jsonObject -> TextPart.serializer() "inlineData" in jsonObject -> BlobPart.serializer() + "functionCall" in jsonObject -> FunctionCallPart.serializer() + "functionResponse" in jsonObject -> FunctionResponsePart.serializer() else -> throw SerializationException("Unknown Part type") } } diff --git a/generativeai/build.gradle.kts b/generativeai/build.gradle.kts index bac2aefa..ac9a9776 100644 --- a/generativeai/build.gradle.kts +++ b/generativeai/build.gradle.kts @@ -79,6 +79,7 @@ dependencies { implementation("org.slf4j:slf4j-nop:2.0.9") implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3") implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3") + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1") implementation("org.reactivestreams:reactive-streams:1.0.3") implementation("com.google.guava:listenablefuture:1.0") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index b536c244..f85eab89 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -19,6 +19,8 @@ package com.google.ai.client.generativeai import android.graphics.Bitmap import com.google.ai.client.generativeai.type.BlobPart import com.google.ai.client.generativeai.type.Content +import com.google.ai.client.generativeai.type.FunctionCallPart +import com.google.ai.client.generativeai.type.FunctionResponsePart import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.ImagePart import com.google.ai.client.generativeai.type.InvalidStateException @@ -27,8 +29,9 @@ import com.google.ai.client.generativeai.type.content import java.util.LinkedList import java.util.concurrent.Semaphore import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.FlowCollector import kotlinx.coroutines.flow.onCompletion -import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.transform /** * Representation of a back and forth interaction with a model. @@ -44,6 +47,10 @@ import kotlinx.coroutines.flow.onEach */ class Chat(private val model: GenerativeModel, val history: MutableList = ArrayList()) { private var lock = Semaphore(1) + companion object{ + private val VALID_ROLES = listOf("user", "function") + } + /** * Generates a response from the backend with the provided [Content], and any previous ones @@ -53,13 +60,27 @@ class Chat(private val model: GenerativeModel, val history: MutableList * @throws InvalidStateException if the prompt is not coming from the 'user' role * @throws InvalidStateException if the [Chat] instance has an active request. */ - suspend fun sendMessage(prompt: Content): GenerateContentResponse { - prompt.assertComesFromUser() + suspend fun sendMessage(inputPrompt: Content): GenerateContentResponse { + inputPrompt.assertComesFromUser() attemptLock() + var response: GenerateContentResponse + var prompt = inputPrompt + val tempHistory = LinkedList() try { - val response = model.generateContent(*history.toTypedArray(), prompt) - history.add(prompt) - history.add(response.candidates.first().content) + while (true) { + response = + model.generateContent(*history.toTypedArray(), *tempHistory.toTypedArray(), prompt) + val responsePart = response.candidates.first().content.parts.first() + + tempHistory.add(prompt) + tempHistory.add(response.candidates.first().content) + if (responsePart !is FunctionCallPart) break + if (model.requestOptions.disableAutoFunction) break + + val output = model.executeFunction(responsePart) + prompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output))) + } + history.addAll(tempHistory) return response } finally { lock.release() @@ -101,43 +122,19 @@ class Chat(private val model: GenerativeModel, val history: MutableList attemptLock() val flow = model.generateContentStream(*history.toTypedArray(), prompt) - val bitmaps = LinkedList() - val blobs = LinkedList() - val text = StringBuilder() - + val tempHistory = mutableListOf() + tempHistory.add(prompt) /** * TODO: revisit when images and blobs are returned. This will cause issues with how things are * structured in the response. eg; a text/image/text response will be (incorrectly) * represented as image/text */ return flow - .onEach { - for (part in it.candidates.first().content.parts) { - when (part) { - is TextPart -> text.append(part.text) - is ImagePart -> bitmaps.add(part.image) - is BlobPart -> blobs.add(part) - } - } - } + .transform { response -> automaticFunctionExecutingTransform(this, tempHistory, response) } .onCompletion { lock.release() if (it == null) { - val content = - content("model") { - for (bitmap in bitmaps) { - image(bitmap) - } - for (blob in blobs) { - blob(blob.mimeType, blob.blob) - } - if (text.isNotBlank()) { - text(text.toString()) - } - } - - history.add(prompt) - history.add(content) + history.addAll(tempHistory) } } } @@ -167,9 +164,73 @@ class Chat(private val model: GenerativeModel, val history: MutableList } private fun Content.assertComesFromUser() { - if (role != "user") { - throw InvalidStateException("Chat prompts should come from the 'user' role.") + if (!VALID_ROLES.contains(role)) { + throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.") + } + } + + private suspend fun automaticFunctionExecutingTransform( + transformer: FlowCollector, + tempHistory: MutableList, + response: GenerateContentResponse, + ) { + for (part in response.candidates.first().content.parts) { + when (part) { + is TextPart -> { + transformer.emit(response) + addTextToHistory(tempHistory, part) + } + is ImagePart -> { + transformer.emit(response) + tempHistory.add(Content("model", listOf(part))) + } + is BlobPart -> { + transformer.emit(response) + tempHistory.add(Content("model", listOf(part))) + } + is FunctionCallPart -> { + if (model.requestOptions.disableAutoFunction) { + tempHistory.add(response.candidates.first().content) + continue + } + val functionCall = + response.candidates.first().content.parts.first { it is FunctionCallPart } + as FunctionCallPart + val output = model.executeFunction(functionCall) + val functionResponse = + Content("function", listOf(FunctionResponsePart(functionCall.name, output))) + tempHistory.add(response.candidates.first().content) + tempHistory.add(functionResponse) + model + .generateContentStream(*history.toTypedArray(), *tempHistory.toTypedArray()) + .collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) } + } + } + } + } + + private fun addTextToHistory(tempHistory: MutableList, textPart: TextPart) { + val lastContent = tempHistory.lastOrNull() + if (lastContent?.role == "model" && lastContent.parts.any { it is TextPart }) { + tempHistory.removeLast() + val editedContent = + Content( + "model", + lastContent.parts.map { + when (it) { + is TextPart -> { + TextPart(it.text + textPart.text) + } + else -> { + it + } + } + }, + ) + tempHistory.add(editedContent) + return } + tempHistory.add(Content("model", listOf(textPart))) } private fun attemptLock() { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 690f5be4..69b29efb 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -25,18 +25,28 @@ import com.google.ai.client.generativeai.internal.util.toPublic import com.google.ai.client.generativeai.type.Content import com.google.ai.client.generativeai.type.CountTokensResponse import com.google.ai.client.generativeai.type.FinishReason +import com.google.ai.client.generativeai.type.FourParameterFunction +import com.google.ai.client.generativeai.type.FunctionCallPart import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.GenerationConfig +import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.GoogleGenerativeAIException +import com.google.ai.client.generativeai.type.NoParameterFunction +import com.google.ai.client.generativeai.type.OneParameterFunction import com.google.ai.client.generativeai.type.PromptBlockedException import com.google.ai.client.generativeai.type.RequestOptions import com.google.ai.client.generativeai.type.ResponseStoppedException import com.google.ai.client.generativeai.type.SafetySetting import com.google.ai.client.generativeai.type.SerializationException +import com.google.ai.client.generativeai.type.ThreeParameterFunction +import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.TwoParameterFunction import com.google.ai.client.generativeai.type.content import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.map +import kotlinx.serialization.ExperimentalSerializationApi +import org.json.JSONObject /** * A facilitator for a given multimodal model (eg; Gemini). @@ -48,14 +58,16 @@ import kotlinx.coroutines.flow.map * generation * @property requestOptions configuration options to utilize during backend communication */ +@OptIn(GenerativeBeta::class, ExperimentalSerializationApi::class) class GenerativeModel internal constructor( val modelName: String, val apiKey: String, val generationConfig: GenerationConfig? = null, val safetySettings: List? = null, + val tools: List? = null, val requestOptions: RequestOptions = RequestOptions(), - private val controller: APIController + private val controller: APIController, ) { @JvmOverloads @@ -64,14 +76,16 @@ internal constructor( apiKey: String, generationConfig: GenerationConfig? = null, safetySettings: List? = null, + tools: List? = null, requestOptions: RequestOptions = RequestOptions(), ) : this( modelName, apiKey, generationConfig, safetySettings, + tools, requestOptions, - APIController(apiKey, modelName, requestOptions.toInternal()) + APIController(apiKey, modelName, requestOptions.toInternal()), ) /** @@ -171,12 +185,44 @@ internal constructor( return countTokens(content { image(prompt) }) } + /** + * Executes a function request by the model. + * + * @param functionCallPart A [FunctionCallPart] from the model, containing a function call and + * parameters + * @return The output of the requested function call + */ + suspend fun executeFunction(functionCallPart: FunctionCallPart): JSONObject { + if (tools == null) { + throw RuntimeException("No registered tools") + } + val tool = tools.first { it.functionDeclarations.any { it.name == functionCallPart.name } } + val callable = + tool.functionDeclarations.firstOrNull() { it.name == functionCallPart.name } + ?: throw RuntimeException("No registered function named ${functionCallPart.name}") + return when (callable) { + is NoParameterFunction -> callable.execute() + is OneParameterFunction<*> -> + (callable as OneParameterFunction).execute(functionCallPart) + is TwoParameterFunction<*, *> -> + (callable as TwoParameterFunction).execute(functionCallPart) + is ThreeParameterFunction<*, *, *> -> + (callable as ThreeParameterFunction).execute(functionCallPart) + is FourParameterFunction<*, *, *, *> -> + (callable as FourParameterFunction).execute(functionCallPart) + else -> { + throw RuntimeException("UNREACHABLE") + } + } + } + private fun constructRequest(vararg prompt: Content) = GenerateContentRequest( modelName, prompt.map { it.toInternal() }, safetySettings?.map { it.toInternal() }, - generationConfig?.toInternal() + generationConfig?.toInternal(), + tools?.map { it.toInternal() }, ) private fun constructCountTokensRequest(vararg prompt: Content) = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index e4b236b7..5ce1f7e1 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -22,6 +22,7 @@ import android.util.Base64 import com.google.ai.client.generativeai.common.CountTokensResponse import com.google.ai.client.generativeai.common.GenerateContentResponse import com.google.ai.client.generativeai.common.RequestOptions +import com.google.ai.client.generativeai.common.client.FunctionParameterProperties import com.google.ai.client.generativeai.common.client.GenerationConfig import com.google.ai.client.generativeai.common.server.BlockReason import com.google.ai.client.generativeai.common.server.Candidate @@ -33,6 +34,10 @@ import com.google.ai.client.generativeai.common.server.SafetyRating import com.google.ai.client.generativeai.common.shared.Blob import com.google.ai.client.generativeai.common.shared.BlobPart import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.FunctionCall +import com.google.ai.client.generativeai.common.shared.FunctionCallPart +import com.google.ai.client.generativeai.common.shared.FunctionResponse +import com.google.ai.client.generativeai.common.shared.FunctionResponsePart import com.google.ai.client.generativeai.common.shared.HarmBlockThreshold import com.google.ai.client.generativeai.common.shared.HarmCategory import com.google.ai.client.generativeai.common.shared.Part @@ -40,15 +45,23 @@ import com.google.ai.client.generativeai.common.shared.SafetySetting import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata +import com.google.ai.client.generativeai.type.FunctionDeclaration +import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.ImagePart +import com.google.ai.client.generativeai.type.ParameterDeclaration import com.google.ai.client.generativeai.type.SerializationException +import com.google.ai.client.generativeai.type.Tool import com.google.ai.client.generativeai.type.content import java.io.ByteArrayOutputStream +import kotlinx.serialization.decodeFromString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import org.json.JSONObject private const val BASE_64_FLAGS = Base64.NO_WRAP internal fun com.google.ai.client.generativeai.type.RequestOptions.toInternal() = - RequestOptions(timeout, apiVersion) + RequestOptions(timeout, apiVersion, disableAutoFunction) internal fun com.google.ai.client.generativeai.type.Content.toInternal() = Content(this.role, this.parts.map { it.toInternal() }) @@ -99,6 +112,37 @@ internal fun BlockThreshold.toInternal() = BlockThreshold.UNSPECIFIED -> HarmBlockThreshold.UNSPECIFIED } +@GenerativeBeta +internal fun Tool.toInternal() = + com.google.ai.client.generativeai.common.client.Tool( + functionDeclarations.map { it.toInternal() } + ) + +@GenerativeBeta +internal fun FunctionDeclaration.toInternal() = + com.google.ai.client.generativeai.common.client.FunctionDeclaration( + name, + description, + FunctionParameterProperties( + properties = getParameters().associate { it.name to it.toInternal() }, + required = getParameters().map { it.name }, + type = "OBJECT", + ), + ) + +internal fun ParameterDeclaration.toInternal(): FunctionParameterProperties = + FunctionParameterProperties( + type.name, + description, + format, + enum, + properties?.mapValues { it.value.toInternal() }, + required, + items?.toInternal() + ) + +internal fun JSONObject.toInternal() = Json.decodeFromString(toString()) + internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candidate { val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty() @@ -126,6 +170,16 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part { com.google.ai.client.generativeai.type.BlobPart(inlineData.mimeType, data) } } + is FunctionCallPart -> + com.google.ai.client.generativeai.type.FunctionCallPart( + functionCall.name, + functionCall.args.orEmpty(), + ) + is FunctionResponsePart -> + com.google.ai.client.generativeai.type.FunctionResponsePart( + functionResponse.name, + functionResponse.response.toPublic(), + ) } } @@ -194,6 +248,8 @@ internal fun GenerateContentResponse.toPublic() = internal fun CountTokensResponse.toPublic() = com.google.ai.client.generativeai.type.CountTokensResponse(totalTokens) +internal fun JsonObject.toPublic() = JSONObject(toString()) + private fun encodeBitmapToBase64Png(input: Bitmap): String { ByteArrayOutputStream().let { input.compress(Bitmap.CompressFormat.JPEG, 80, it) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt new file mode 100644 index 00000000..227e40f9 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -0,0 +1,291 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +import org.json.JSONObject + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property function the function implementation + */ +@GenerativeBeta +class NoParameterFunction( + name: String, + description: String, + val function: suspend () -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf>() + + suspend fun execute() = function() + + override suspend fun execute(part: FunctionCallPart) = function() +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param A description of the first function parameter + * @property function the function implementation + */ +@GenerativeBeta +class OneParameterFunction( + name: String, + description: String, + val param: ParameterDeclaration, + val function: suspend (T) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param) + return function(arg1) + } +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property function the function implementation + */ +@GenerativeBeta +class TwoParameterFunction( + name: String, + description: String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val function: suspend (T, U) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + return function(arg1, arg2) + } +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property param3 A description of the third function parameter + * @property function the function implementation + */ +@GenerativeBeta +class ThreeParameterFunction( + name: String, + description: String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val param3: ParameterDeclaration, + val function: suspend (T, U, V) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2, param3) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + return function(arg1, arg2, arg3) + } +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property param3 A description of the third function parameter + * @property param4 A description of the fourth function parameter + * @property function the function implementation + */ +@GenerativeBeta +class FourParameterFunction( + name: String, + description: String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val param3: ParameterDeclaration, + val param4: ParameterDeclaration, + val function: suspend (T, U, V, W) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2, param3, param4) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + val arg4 = part.getArgOrThrow(param4) + return function(arg1, arg2, arg3, arg4) + } +} + +@GenerativeBeta +abstract class FunctionDeclaration(val name: String, val description: String) { + abstract fun getParameters(): List> + + abstract suspend fun execute(part: FunctionCallPart): JSONObject +} + +/** + * Represents a parameter for a declared function + * + * @property name: The name of the parameter + * @property description: The description of what the parameter should contain or represent + * @property format: format information for the parameter, this can include bitlength in the case of + * int/float or keywords like "enum" for the string type + * @property enum: contains the enum values for a string enum + * @property type: contains the type info and parser + * @property properties: if type is OBJECT, then this contains the description of the fields of the + * object by name + * @property required: if type is OBJECT, then this contains the list of required keys + * @property items: if the type is ARRAY, then this contains a description of the objects in the + * array + */ +class ParameterDeclaration( + val name: String, + val description: String, + val format: String? = null, + val enum: List? = null, + val properties: Map>? = null, + val required: List? = null, + val items: ParameterDeclaration? = null, + val type: FunctionType, +) { + fun fromString(value: String?) = type.parse(value) + + companion object { + fun int(name: String, description: String) = + ParameterDeclaration( + name = name, + description = description, + type = FunctionType.INTEGER, + ) + + fun str(name: String, description: String) = + ParameterDeclaration( + name = name, + description = description, + type = FunctionType.STRING, + ) + + fun bool(name: String, description: String) = + ParameterDeclaration( + name = name, + description = description, + type = FunctionType.BOOLEAN, + ) + + fun num(name: String, description: String) = + ParameterDeclaration( + name = name, + description = description, + type = FunctionType.NUMBER, + ) + + fun obj(name: String, description: String) = + ParameterDeclaration( + name = name, + description = description, + type = FunctionType.OBJECT, + ) + + fun arr(name: String, description: String) = + ParameterDeclaration>( + name = name, + description = description, + type = FunctionType.ARRAY, + ) + + fun enum(name: String, description: String, values: List) = + ParameterDeclaration( + name = name, + description = description, + format = "enum", + enum = values, + type = FunctionType.STRING, + ) + } +} + +@GenerativeBeta +fun defineFunction(name: String, description: String, function: suspend () -> JSONObject) = + NoParameterFunction(name, description, function) + +@GenerativeBeta +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + function: suspend (T) -> JSONObject, +) = OneParameterFunction(name, description, arg1, function) + +@GenerativeBeta +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + function: suspend (T, U) -> JSONObject, +) = TwoParameterFunction(name, description, arg1, arg2, function) + +@GenerativeBeta +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + arg3: ParameterDeclaration, + function: suspend (T, U, W) -> JSONObject, +) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) + +@GenerativeBeta +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + arg3: ParameterDeclaration, + arg4: ParameterDeclaration, + function: suspend (T, U, W, Z) -> JSONObject, +) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) + +private fun FunctionCallPart.getArgOrThrow(param: ParameterDeclaration): T { + return param.fromString(args[param.name]) + ?: throw RuntimeException( + "Missing argument for parameter \"${param.name}\" for function \"$name\"" + ) +} \ No newline at end of file diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt new file mode 100644 index 00000000..614f0eba --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt @@ -0,0 +1,19 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +class FunctionParameter(val name: String, val description: String, val type: FunctionType) {} \ No newline at end of file diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt index 9e33ce16..9fe30329 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt @@ -32,6 +32,12 @@ class GenerateContentResponse( /** Convenience field representing the first text part in the response, if it exists. */ val text: String? by lazy { firstPartAs()?.text } + /** Convenience field representing the first text part in the response, if it exists. */ + val functionCall: FunctionCallPart? by lazy { firstPartAs() } + + /** Convenience field representing the first text part in the response, if it exists. */ + val functionResponse: FunctionResponsePart? by lazy { firstPartAs() } + private inline fun firstPartAs(): T? { if (candidates.isEmpty()) { warn("No candidates were found, but was asked to get a candidate.") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt new file mode 100644 index 00000000..be0f0dd6 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +@RequiresOptIn(message = "This API is only available on the v1beta endpoint") +@Retention(AnnotationRetention.BINARY) +@Target(AnnotationTarget.CLASS, AnnotationTarget.FUNCTION) +annotation class GenerativeBeta \ No newline at end of file diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt index 4c54cc43..83f22953 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt @@ -17,6 +17,7 @@ package com.google.ai.client.generativeai.type import android.graphics.Bitmap +import org.json.JSONObject /** * Interface representing data sent to and received from requests. @@ -40,6 +41,12 @@ class ImagePart(val image: Bitmap) : Part /** Represents binary data with an associated MIME type sent to and received from requests. */ class BlobPart(val mimeType: String, val blob: ByteArray) : Part +/** Represents function call name and params received from requests. */ +class FunctionCallPart(val name: String, val args: Map) : Part + +/** Represents function call output to be returned to the model when it requests a function call */ +class FunctionResponsePart(val name: String, val response: JSONObject) : Part + /** @return The part as a [String] if it represents text, and null otherwise */ fun Part.asTextOrNull(): String? = (this as? TextPart)?.text diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt index e25561fb..147a9895 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt @@ -26,11 +26,21 @@ import kotlin.time.toDuration * @property timeout the maximum amount of time for a request to take, from the first request to * first response. * @property apiVersion the api endpoint to call. + * @property disableAutoFunction if true, auto functions will not be automatically executed */ -class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") { +class RequestOptions( + val timeout: Duration, + val apiVersion: String = "v1", + val disableAutoFunction: Boolean = false, +) { @JvmOverloads constructor( timeout: Long? = Long.MAX_VALUE, - apiVersion: String = "v1" - ) : this((timeout ?: Long.MAX_VALUE).toDuration(DurationUnit.MILLISECONDS), apiVersion) + apiVersion: String = "v1", + disableAutoFunction: Boolean = false, + ) : this( + (timeout ?: Long.MAX_VALUE).toDuration(DurationUnit.MILLISECONDS), + apiVersion, + disableAutoFunction, + ) } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt new file mode 100644 index 00000000..ecccff2a --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt @@ -0,0 +1,28 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +/** + * Contains a set of function declarations that the model has access to. These can be used to gather + * information, or complete tasks + * + * @param functionDeclarations The set of functions that this tool allows the model access to + */ +@OptIn(GenerativeBeta::class) +class Tool( + val functionDeclarations: List, +) \ No newline at end of file diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt new file mode 100644 index 00000000..a2c236db --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt @@ -0,0 +1,42 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.jsonArray +import org.json.JSONObject + +/** + * Represents and passes the type information for an automated function call. + * + * @property name: the enum name of the type + * @property parse: the deserialization function + * @property T: the type of the object that this maps to in code. + */ +class FunctionType(val name: String, val parse: (String?) -> T?) { + companion object { + val STRING = FunctionType("STRING") { it } + val INTEGER = FunctionType("INTEGER") { it?.toLongOrNull() } + val NUMBER = FunctionType("NUMBER") { it?.toDoubleOrNull() } + val BOOLEAN = FunctionType("BOOLEAN") { it?.toBoolean() } + val ARRAY = + FunctionType>("ARRAY") { it -> + it?.let { Json.parseToJsonElement(it).jsonArray.map { element -> element.toString() } } + } + val OBJECT = FunctionType("OBJECT") { it?.let { JSONObject(it) } } + } +} \ No newline at end of file From 648887190a2332bd17be43682a229d3523fd18b3 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Wed, 20 Mar 2024 18:07:50 -0700 Subject: [PATCH 2/6] ktfmt common --- .../com/google/ai/client/generativeai/common/Request.kt | 3 +-- .../google/ai/client/generativeai/common/client/Types.kt | 6 +++--- .../google/ai/client/generativeai/common/shared/Types.kt | 9 ++------- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index bdb9ec13..4e484266 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -20,7 +20,6 @@ import com.google.ai.client.generativeai.common.client.GenerationConfig import com.google.ai.client.generativeai.common.client.Tool import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.SafetySetting - import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -32,7 +31,7 @@ data class GenerateContentRequest( val contents: List, @SerialName("safety_settings") val safetySettings: List? = null, @SerialName("generation_config") val generationConfig: GenerationConfig? = null, - val tools: List? = null + val tools: List? = null, ) : Request @Serializable diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index 41507836..0a7335ce 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -26,7 +26,7 @@ data class GenerationConfig( @SerialName("top_k") val topK: Int?, @SerialName("candidate_count") val candidateCount: Int?, @SerialName("max_output_tokens") val maxOutputTokens: Int?, - @SerialName("stop_sequences") val stopSequences: List? + @SerialName("stop_sequences") val stopSequences: List?, ) @Serializable data class Tool(val functionDeclarations: List) @@ -46,5 +46,5 @@ data class FunctionParameterProperties( val enum: List? = null, val properties: Map? = null, val required: List? = null, - val items: FunctionParameterProperties? = null -) \ No newline at end of file + val items: FunctionParameterProperties? = null, +) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt index 7fd246c7..cc983fb8 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -55,14 +55,9 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List Date: Wed, 20 Mar 2024 18:10:51 -0700 Subject: [PATCH 3/6] full ktfmt --- .../com/google/ai/client/generativeai/Chat.kt | 4 +- .../generativeai/internal/util/conversions.kt | 6 +- .../generativeai/type/FunctionDeclarations.kt | 278 +++++++++--------- .../generativeai/type/FunctionParameter.kt | 2 +- .../generativeai/type/GenerativeBeta.kt | 2 +- .../ai/client/generativeai/type/Tool.kt | 4 +- .../ai/client/generativeai/type/Type.kt | 24 +- 7 files changed, 158 insertions(+), 162 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index f85eab89..d5221b67 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -47,11 +47,11 @@ import kotlinx.coroutines.flow.transform */ class Chat(private val model: GenerativeModel, val history: MutableList = ArrayList()) { private var lock = Semaphore(1) - companion object{ + + companion object { private val VALID_ROLES = listOf("user", "function") } - /** * Generates a response from the backend with the provided [Content], and any previous ones * sent/returned from this chat. diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index 5ce1f7e1..45ae00cf 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -34,9 +34,7 @@ import com.google.ai.client.generativeai.common.server.SafetyRating import com.google.ai.client.generativeai.common.shared.Blob import com.google.ai.client.generativeai.common.shared.BlobPart import com.google.ai.client.generativeai.common.shared.Content -import com.google.ai.client.generativeai.common.shared.FunctionCall import com.google.ai.client.generativeai.common.shared.FunctionCallPart -import com.google.ai.client.generativeai.common.shared.FunctionResponse import com.google.ai.client.generativeai.common.shared.FunctionResponsePart import com.google.ai.client.generativeai.common.shared.HarmBlockThreshold import com.google.ai.client.generativeai.common.shared.HarmCategory @@ -114,9 +112,7 @@ internal fun BlockThreshold.toInternal() = @GenerativeBeta internal fun Tool.toInternal() = - com.google.ai.client.generativeai.common.client.Tool( - functionDeclarations.map { it.toInternal() } - ) + com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() }) @GenerativeBeta internal fun FunctionDeclaration.toInternal() = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 227e40f9..177934a1 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -28,15 +28,15 @@ import org.json.JSONObject */ @GenerativeBeta class NoParameterFunction( - name: String, - description: String, - val function: suspend () -> JSONObject, + name: String, + description: String, + val function: suspend () -> JSONObject, ) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf>() + override fun getParameters() = listOf>() - suspend fun execute() = function() + suspend fun execute() = function() - override suspend fun execute(part: FunctionCallPart) = function() + override suspend fun execute(part: FunctionCallPart) = function() } /** @@ -50,17 +50,17 @@ class NoParameterFunction( */ @GenerativeBeta class OneParameterFunction( - name: String, - description: String, - val param: ParameterDeclaration, - val function: suspend (T) -> JSONObject, + name: String, + description: String, + val param: ParameterDeclaration, + val function: suspend (T) -> JSONObject, ) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf(param) + override fun getParameters() = listOf(param) - override suspend fun execute(part: FunctionCallPart): JSONObject { - val arg1 = part.getArgOrThrow(param) - return function(arg1) - } + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param) + return function(arg1) + } } /** @@ -75,19 +75,19 @@ class OneParameterFunction( */ @GenerativeBeta class TwoParameterFunction( - name: String, - description: String, - val param1: ParameterDeclaration, - val param2: ParameterDeclaration, - val function: suspend (T, U) -> JSONObject, + name: String, + description: String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val function: suspend (T, U) -> JSONObject, ) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf(param1, param2) + override fun getParameters() = listOf(param1, param2) - override suspend fun execute(part: FunctionCallPart): JSONObject { - val arg1 = part.getArgOrThrow(param1) - val arg2 = part.getArgOrThrow(param2) - return function(arg1, arg2) - } + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + return function(arg1, arg2) + } } /** @@ -103,21 +103,21 @@ class TwoParameterFunction( */ @GenerativeBeta class ThreeParameterFunction( - name: String, - description: String, - val param1: ParameterDeclaration, - val param2: ParameterDeclaration, - val param3: ParameterDeclaration, - val function: suspend (T, U, V) -> JSONObject, + name: String, + description: String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val param3: ParameterDeclaration, + val function: suspend (T, U, V) -> JSONObject, ) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf(param1, param2, param3) + override fun getParameters() = listOf(param1, param2, param3) - override suspend fun execute(part: FunctionCallPart): JSONObject { - val arg1 = part.getArgOrThrow(param1) - val arg2 = part.getArgOrThrow(param2) - val arg3 = part.getArgOrThrow(param3) - return function(arg1, arg2, arg3) - } + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + return function(arg1, arg2, arg3) + } } /** @@ -134,30 +134,30 @@ class ThreeParameterFunction( */ @GenerativeBeta class FourParameterFunction( - name: String, - description: String, - val param1: ParameterDeclaration, - val param2: ParameterDeclaration, - val param3: ParameterDeclaration, - val param4: ParameterDeclaration, - val function: suspend (T, U, V, W) -> JSONObject, + name: String, + description: String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val param3: ParameterDeclaration, + val param4: ParameterDeclaration, + val function: suspend (T, U, V, W) -> JSONObject, ) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf(param1, param2, param3, param4) + override fun getParameters() = listOf(param1, param2, param3, param4) - override suspend fun execute(part: FunctionCallPart): JSONObject { - val arg1 = part.getArgOrThrow(param1) - val arg2 = part.getArgOrThrow(param2) - val arg3 = part.getArgOrThrow(param3) - val arg4 = part.getArgOrThrow(param4) - return function(arg1, arg2, arg3, arg4) - } + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + val arg4 = part.getArgOrThrow(param4) + return function(arg1, arg2, arg3, arg4) + } } @GenerativeBeta abstract class FunctionDeclaration(val name: String, val description: String) { - abstract fun getParameters(): List> + abstract fun getParameters(): List> - abstract suspend fun execute(part: FunctionCallPart): JSONObject + abstract suspend fun execute(part: FunctionCallPart): JSONObject } /** @@ -176,116 +176,116 @@ abstract class FunctionDeclaration(val name: String, val description: String) { * array */ class ParameterDeclaration( - val name: String, - val description: String, - val format: String? = null, - val enum: List? = null, - val properties: Map>? = null, - val required: List? = null, - val items: ParameterDeclaration? = null, - val type: FunctionType, + val name: String, + val description: String, + val format: String? = null, + val enum: List? = null, + val properties: Map>? = null, + val required: List? = null, + val items: ParameterDeclaration? = null, + val type: FunctionType, ) { - fun fromString(value: String?) = type.parse(value) + fun fromString(value: String?) = type.parse(value) - companion object { - fun int(name: String, description: String) = - ParameterDeclaration( - name = name, - description = description, - type = FunctionType.INTEGER, - ) + companion object { + fun int(name: String, description: String) = + ParameterDeclaration( + name = name, + description = description, + type = FunctionType.INTEGER, + ) - fun str(name: String, description: String) = - ParameterDeclaration( - name = name, - description = description, - type = FunctionType.STRING, - ) + fun str(name: String, description: String) = + ParameterDeclaration( + name = name, + description = description, + type = FunctionType.STRING, + ) - fun bool(name: String, description: String) = - ParameterDeclaration( - name = name, - description = description, - type = FunctionType.BOOLEAN, - ) + fun bool(name: String, description: String) = + ParameterDeclaration( + name = name, + description = description, + type = FunctionType.BOOLEAN, + ) - fun num(name: String, description: String) = - ParameterDeclaration( - name = name, - description = description, - type = FunctionType.NUMBER, - ) + fun num(name: String, description: String) = + ParameterDeclaration( + name = name, + description = description, + type = FunctionType.NUMBER, + ) - fun obj(name: String, description: String) = - ParameterDeclaration( - name = name, - description = description, - type = FunctionType.OBJECT, - ) + fun obj(name: String, description: String) = + ParameterDeclaration( + name = name, + description = description, + type = FunctionType.OBJECT, + ) - fun arr(name: String, description: String) = - ParameterDeclaration>( - name = name, - description = description, - type = FunctionType.ARRAY, - ) + fun arr(name: String, description: String) = + ParameterDeclaration>( + name = name, + description = description, + type = FunctionType.ARRAY, + ) - fun enum(name: String, description: String, values: List) = - ParameterDeclaration( - name = name, - description = description, - format = "enum", - enum = values, - type = FunctionType.STRING, - ) - } + fun enum(name: String, description: String, values: List) = + ParameterDeclaration( + name = name, + description = description, + format = "enum", + enum = values, + type = FunctionType.STRING, + ) + } } @GenerativeBeta fun defineFunction(name: String, description: String, function: suspend () -> JSONObject) = - NoParameterFunction(name, description, function) + NoParameterFunction(name, description, function) @GenerativeBeta fun defineFunction( - name: String, - description: String, - arg1: ParameterDeclaration, - function: suspend (T) -> JSONObject, + name: String, + description: String, + arg1: ParameterDeclaration, + function: suspend (T) -> JSONObject, ) = OneParameterFunction(name, description, arg1, function) @GenerativeBeta fun defineFunction( - name: String, - description: String, - arg1: ParameterDeclaration, - arg2: ParameterDeclaration, - function: suspend (T, U) -> JSONObject, + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + function: suspend (T, U) -> JSONObject, ) = TwoParameterFunction(name, description, arg1, arg2, function) @GenerativeBeta fun defineFunction( - name: String, - description: String, - arg1: ParameterDeclaration, - arg2: ParameterDeclaration, - arg3: ParameterDeclaration, - function: suspend (T, U, W) -> JSONObject, + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + arg3: ParameterDeclaration, + function: suspend (T, U, W) -> JSONObject, ) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) @GenerativeBeta fun defineFunction( - name: String, - description: String, - arg1: ParameterDeclaration, - arg2: ParameterDeclaration, - arg3: ParameterDeclaration, - arg4: ParameterDeclaration, - function: suspend (T, U, W, Z) -> JSONObject, + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + arg3: ParameterDeclaration, + arg4: ParameterDeclaration, + function: suspend (T, U, W, Z) -> JSONObject, ) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) private fun FunctionCallPart.getArgOrThrow(param: ParameterDeclaration): T { - return param.fromString(args[param.name]) - ?: throw RuntimeException( - "Missing argument for parameter \"${param.name}\" for function \"$name\"" - ) -} \ No newline at end of file + return param.fromString(args[param.name]) + ?: throw RuntimeException( + "Missing argument for parameter \"${param.name}\" for function \"$name\"" + ) +} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt index 614f0eba..cb9ccbe1 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt @@ -16,4 +16,4 @@ package com.google.ai.client.generativeai.type -class FunctionParameter(val name: String, val description: String, val type: FunctionType) {} \ No newline at end of file +class FunctionParameter(val name: String, val description: String, val type: FunctionType) {} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt index be0f0dd6..518f301d 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt @@ -19,4 +19,4 @@ package com.google.ai.client.generativeai.type @RequiresOptIn(message = "This API is only available on the v1beta endpoint") @Retention(AnnotationRetention.BINARY) @Target(AnnotationTarget.CLASS, AnnotationTarget.FUNCTION) -annotation class GenerativeBeta \ No newline at end of file +annotation class GenerativeBeta diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt index ecccff2a..67467710 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt @@ -24,5 +24,5 @@ package com.google.ai.client.generativeai.type */ @OptIn(GenerativeBeta::class) class Tool( - val functionDeclarations: List, -) \ No newline at end of file + val functionDeclarations: List, +) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt index a2c236db..c220031d 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt @@ -28,15 +28,15 @@ import org.json.JSONObject * @property T: the type of the object that this maps to in code. */ class FunctionType(val name: String, val parse: (String?) -> T?) { - companion object { - val STRING = FunctionType("STRING") { it } - val INTEGER = FunctionType("INTEGER") { it?.toLongOrNull() } - val NUMBER = FunctionType("NUMBER") { it?.toDoubleOrNull() } - val BOOLEAN = FunctionType("BOOLEAN") { it?.toBoolean() } - val ARRAY = - FunctionType>("ARRAY") { it -> - it?.let { Json.parseToJsonElement(it).jsonArray.map { element -> element.toString() } } - } - val OBJECT = FunctionType("OBJECT") { it?.let { JSONObject(it) } } - } -} \ No newline at end of file + companion object { + val STRING = FunctionType("STRING") { it } + val INTEGER = FunctionType("INTEGER") { it?.toLongOrNull() } + val NUMBER = FunctionType("NUMBER") { it?.toDoubleOrNull() } + val BOOLEAN = FunctionType("BOOLEAN") { it?.toBoolean() } + val ARRAY = + FunctionType>("ARRAY") { it -> + it?.let { Json.parseToJsonElement(it).jsonArray.map { element -> element.toString() } } + } + val OBJECT = FunctionType("OBJECT") { it?.let { JSONObject(it) } } + } +} From bfe0ae8e9826b386fce5227f900ae2a3941348cc Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Wed, 27 Mar 2024 14:50:10 -0500 Subject: [PATCH 4/6] Add constrained decoding support --- .../ai/client/generativeai/common/Request.kt | 2 + .../generativeai/common/client/Types.kt | 15 +++++++ .../ai/client/generativeai/GenerativeModel.kt | 5 +++ .../generativeai/internal/util/conversions.kt | 17 ++++++++ .../type/FunctionCallingConfig.kt | 43 +++++++++++++++++++ .../ai/client/generativeai/type/ToolConfig.kt | 34 +++++++++++++++ 6 files changed, 116 insertions(+) create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index 4e484266..f874c0b4 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.client.GenerationConfig import com.google.ai.client.generativeai.common.client.Tool +import com.google.ai.client.generativeai.common.client.ToolConfig import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.SafetySetting import kotlinx.serialization.SerialName @@ -32,6 +33,7 @@ data class GenerateContentRequest( @SerialName("safety_settings") val safetySettings: List? = null, @SerialName("generation_config") val generationConfig: GenerationConfig? = null, val tools: List? = null, + @SerialName("tool_config") var toolConfig: ToolConfig? = null, ) : Request @Serializable diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index 0a7335ce..d9be4d32 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -31,6 +31,21 @@ data class GenerationConfig( @Serializable data class Tool(val functionDeclarations: List) +@Serializable +data class ToolConfig( + @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig +) + +@Serializable +data class FunctionCallingConfig(val mode: Mode) { + @Serializable + enum class Mode { + @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, + AUTO, + ANY, + NONE + } +} @Serializable data class FunctionDeclaration( val name: String, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 69b29efb..2a3ebeda 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -40,6 +40,7 @@ import com.google.ai.client.generativeai.type.SafetySetting import com.google.ai.client.generativeai.type.SerializationException import com.google.ai.client.generativeai.type.ThreeParameterFunction import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.ToolConfig import com.google.ai.client.generativeai.type.TwoParameterFunction import com.google.ai.client.generativeai.type.content import kotlinx.coroutines.flow.Flow @@ -66,6 +67,7 @@ internal constructor( val generationConfig: GenerationConfig? = null, val safetySettings: List? = null, val tools: List? = null, + val toolConfig: ToolConfig? = null, val requestOptions: RequestOptions = RequestOptions(), private val controller: APIController, ) { @@ -77,6 +79,7 @@ internal constructor( generationConfig: GenerationConfig? = null, safetySettings: List? = null, tools: List? = null, + toolConfig: ToolConfig? = null, requestOptions: RequestOptions = RequestOptions(), ) : this( modelName, @@ -84,6 +87,7 @@ internal constructor( generationConfig, safetySettings, tools, + toolConfig, requestOptions, APIController(apiKey, modelName, requestOptions.toInternal()), ) @@ -223,6 +227,7 @@ internal constructor( safetySettings?.map { it.toInternal() }, generationConfig?.toInternal(), tools?.map { it.toInternal() }, + toolConfig?.toInternal(), ) private fun constructCountTokensRequest(vararg prompt: Content) = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index 45ae00cf..7e9f01d9 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -43,12 +43,14 @@ import com.google.ai.client.generativeai.common.shared.SafetySetting import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata +import com.google.ai.client.generativeai.type.FunctionCallingConfig import com.google.ai.client.generativeai.type.FunctionDeclaration import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.ImagePart import com.google.ai.client.generativeai.type.ParameterDeclaration import com.google.ai.client.generativeai.type.SerializationException import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.ToolConfig import com.google.ai.client.generativeai.type.content import java.io.ByteArrayOutputStream import kotlinx.serialization.decodeFromString @@ -114,6 +116,21 @@ internal fun BlockThreshold.toInternal() = internal fun Tool.toInternal() = com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() }) +@GenerativeBeta +internal fun ToolConfig.toInternal() = + com.google.ai.client.generativeai.common.client.ToolConfig( + com.google.ai.client.generativeai.common.client.FunctionCallingConfig( + when (functionCallingConfig.mode) { + FunctionCallingConfig.Mode.ANY -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.ANY + FunctionCallingConfig.Mode.AUTO -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.AUTO + FunctionCallingConfig.Mode.NONE -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.NONE + } + ) + ) + @GenerativeBeta internal fun FunctionDeclaration.toInternal() = com.google.ai.client.generativeai.common.client.FunctionDeclaration( diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt new file mode 100644 index 00000000..a2da761c --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt @@ -0,0 +1,43 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +/** + * Contains configuration for function calling from the model. This can be used to force function + * calling predictions or disable them. + * + * @param mode The function calling mode of the model + */ +@GenerativeBeta +class FunctionCallingConfig(val mode: Mode) { + enum class Mode { + /** + * The default behavior for function calling. The model calls functions to answer queries at its + * discretion + */ + AUTO, + + /** The model always predicts a provided function call to answer every query. */ + ANY, + + /** + * The model will never predict a function call to answer a query. This can also be achieved by + * not passing any tools to the model. + */ + NONE + } +} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt new file mode 100644 index 00000000..cc392320 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +/** + * Contains configuration for the function calling tools of the model. This can be used to change + * when the model can predict function calls. + * + * @param functionCallingConfig The config for function calling + */ +@OptIn(GenerativeBeta::class) +class ToolConfig(val functionCallingConfig: FunctionCallingConfig) { + + companion object { + /** Shorthand to construct a ToolConfig that restricts the model from calling any functions */ + fun never(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.NONE)) + /** Shorthand to construct a ToolConfig that restricts the model to always call some function */ + fun always(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY)) + } +} From b5dda3012d783c2b0dbbc2d396346462594a020c Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Thu, 28 Mar 2024 11:41:01 -0500 Subject: [PATCH 5/6] Adjust formatting --- .../com/google/ai/client/generativeai/common/client/Types.kt | 1 + 1 file changed, 1 insertion(+) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index d9be4d32..e80a95ce 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -46,6 +46,7 @@ data class FunctionCallingConfig(val mode: Mode) { NONE } } + @Serializable data class FunctionDeclaration( val name: String, From fd3081fe392f5aada65763f6b53d1999862fa7a6 Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Tue, 2 Apr 2024 10:43:15 -0500 Subject: [PATCH 6/6] Add serialization tests --- .../generativeai/common/APIControllerTests.kt | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt index a05d6d67..6104ab41 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt @@ -16,17 +16,21 @@ package com.google.ai.client.generativeai.common +import com.google.ai.client.generativeai.common.client.FunctionCallingConfig +import com.google.ai.client.generativeai.common.client.ToolConfig import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.common.util.commonTest import com.google.ai.client.generativeai.common.util.createResponses import com.google.ai.client.generativeai.common.util.doBlocking import com.google.ai.client.generativeai.common.util.prepareStreamingResponse +import io.kotest.assertions.json.shouldContainJsonKey import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.mock.respond +import io.ktor.content.TextContent import io.ktor.http.HttpHeaders import io.ktor.http.HttpStatusCode import io.ktor.http.headersOf @@ -114,6 +118,38 @@ internal class EndpointTests { mockEngine.requestHistory.first().url.host shouldBe "my.custom.endpoint" } + + @Test + fun `ToolConfig serialization is correct`() = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + + val controller = + APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine) + + withTimeout(5.seconds) { + controller + .generateContentStream( + GenerateContentRequest( + model = "unused", + contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))), + toolConfig = + ToolConfig( + functionCallingConfig = + FunctionCallingConfig(mode = FunctionCallingConfig.Mode.AUTO) + ) + ) + ) + .collect { channel.close() } + } + + val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text + + requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode" + } } @RunWith(Parameterized::class)