diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index cb6deff5..18b7dded 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.client.GenerationConfig import com.google.ai.client.generativeai.common.client.Tool +import com.google.ai.client.generativeai.common.client.ToolConfig import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.SafetySetting import kotlinx.serialization.SerialName @@ -33,6 +34,7 @@ data class GenerateContentRequest( @SerialName("safety_settings") val safetySettings: List? = null, @SerialName("generation_config") val generationConfig: GenerationConfig? = null, val tools: List? = null, + @SerialName("tool_config") var toolConfig: ToolConfig? = null, ) : Request @Serializable diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index 00d44cc9..c3bff52b 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -31,6 +31,22 @@ data class GenerationConfig( @Serializable data class Tool(val functionDeclarations: List) +@Serializable +data class ToolConfig( + @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig +) + +@Serializable +data class FunctionCallingConfig(val mode: Mode) { + @Serializable + enum class Mode { + @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, + AUTO, + ANY, + NONE + } +} + @Serializable data class FunctionDeclaration( val name: String, diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt index 7674a553..6318b50c 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt @@ -16,6 +16,8 @@ package com.google.ai.client.generativeai.common +import com.google.ai.client.generativeai.common.client.FunctionCallingConfig +import com.google.ai.client.generativeai.common.client.ToolConfig import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.common.util.commonTest @@ -157,6 +159,38 @@ internal class RequestFormatTests { requestBodyAsText shouldContainJsonKey "contents" requestBodyAsText shouldNotContainJsonKey "model" } + + @Test + fun `ToolConfig serialization contains correct keys`() = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + + val controller = + APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine) + + withTimeout(5.seconds) { + controller + .generateContentStream( + GenerateContentRequest( + model = "unused", + contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))), + toolConfig = + ToolConfig( + functionCallingConfig = + FunctionCallingConfig(mode = FunctionCallingConfig.Mode.AUTO) + ) + ) + ) + .collect { channel.close() } + } + + val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text + + requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode" + } } @RunWith(Parameterized::class) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 90dc8daa..a2f285f1 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -41,6 +41,7 @@ import com.google.ai.client.generativeai.type.SafetySetting import com.google.ai.client.generativeai.type.SerializationException import com.google.ai.client.generativeai.type.ThreeParameterFunction import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.ToolConfig import com.google.ai.client.generativeai.type.TwoParameterFunction import com.google.ai.client.generativeai.type.content import kotlinx.coroutines.flow.Flow @@ -67,6 +68,7 @@ internal constructor( val generationConfig: GenerationConfig? = null, val safetySettings: List? = null, val tools: List? = null, + val toolConfig: ToolConfig? = null, val requestOptions: RequestOptions = RequestOptions(), private val controller: APIController, ) { @@ -78,6 +80,7 @@ internal constructor( generationConfig: GenerationConfig? = null, safetySettings: List? = null, tools: List? = null, + toolConfig: ToolConfig? = null, requestOptions: RequestOptions = RequestOptions(), ) : this( modelName, @@ -85,6 +88,7 @@ internal constructor( generationConfig, safetySettings, tools, + toolConfig, requestOptions, APIController(apiKey, modelName, requestOptions.toInternal()), ) @@ -225,6 +229,7 @@ internal constructor( safetySettings?.map { it.toInternal() }, generationConfig?.toInternal(), tools?.map { it.toInternal() }, + toolConfig?.toInternal(), ) private fun constructCountTokensRequest(vararg prompt: Content) = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index c1b813e5..d2be5aa4 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -45,11 +45,13 @@ import com.google.ai.client.generativeai.common.shared.SafetySetting import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata +import com.google.ai.client.generativeai.type.FunctionCallingConfig import com.google.ai.client.generativeai.type.FunctionDeclaration import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.ImagePart import com.google.ai.client.generativeai.type.SerializationException import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.ToolConfig import com.google.ai.client.generativeai.type.content import java.io.ByteArrayOutputStream import kotlinx.serialization.json.Json @@ -118,6 +120,21 @@ internal fun BlockThreshold.toInternal() = internal fun Tool.toInternal() = com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() }) +@GenerativeBeta +internal fun ToolConfig.toInternal() = + com.google.ai.client.generativeai.common.client.ToolConfig( + com.google.ai.client.generativeai.common.client.FunctionCallingConfig( + when (functionCallingConfig.mode) { + FunctionCallingConfig.Mode.ANY -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.ANY + FunctionCallingConfig.Mode.AUTO -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.AUTO + FunctionCallingConfig.Mode.NONE -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.NONE + } + ) + ) + @GenerativeBeta internal fun FunctionDeclaration.toInternal() = com.google.ai.client.generativeai.common.client.FunctionDeclaration( diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt new file mode 100644 index 00000000..a2da761c --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt @@ -0,0 +1,43 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +/** + * Contains configuration for function calling from the model. This can be used to force function + * calling predictions or disable them. + * + * @param mode The function calling mode of the model + */ +@GenerativeBeta +class FunctionCallingConfig(val mode: Mode) { + enum class Mode { + /** + * The default behavior for function calling. The model calls functions to answer queries at its + * discretion + */ + AUTO, + + /** The model always predicts a provided function call to answer every query. */ + ANY, + + /** + * The model will never predict a function call to answer a query. This can also be achieved by + * not passing any tools to the model. + */ + NONE + } +} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt new file mode 100644 index 00000000..cc392320 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +/** + * Contains configuration for the function calling tools of the model. This can be used to change + * when the model can predict function calls. + * + * @param functionCallingConfig The config for function calling + */ +@OptIn(GenerativeBeta::class) +class ToolConfig(val functionCallingConfig: FunctionCallingConfig) { + + companion object { + /** Shorthand to construct a ToolConfig that restricts the model from calling any functions */ + fun never(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.NONE)) + /** Shorthand to construct a ToolConfig that restricts the model to always call some function */ + fun always(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY)) + } +}