diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt index 4e484266..f874c0b4 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.client.GenerationConfig import com.google.ai.client.generativeai.common.client.Tool +import com.google.ai.client.generativeai.common.client.ToolConfig import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.SafetySetting import kotlinx.serialization.SerialName @@ -32,6 +33,7 @@ data class GenerateContentRequest( @SerialName("safety_settings") val safetySettings: List? = null, @SerialName("generation_config") val generationConfig: GenerationConfig? = null, val tools: List? = null, + @SerialName("tool_config") var toolConfig: ToolConfig? = null, ) : Request @Serializable diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index 0a7335ce..d9be4d32 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -31,6 +31,21 @@ data class GenerationConfig( @Serializable data class Tool(val functionDeclarations: List) +@Serializable +data class ToolConfig( + @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig +) + +@Serializable +data class FunctionCallingConfig(val mode: Mode) { + @Serializable + enum class Mode { + @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, + AUTO, + ANY, + NONE + } +} @Serializable data class FunctionDeclaration( val name: String, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 69b29efb..2a3ebeda 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -40,6 +40,7 @@ import com.google.ai.client.generativeai.type.SafetySetting import com.google.ai.client.generativeai.type.SerializationException import com.google.ai.client.generativeai.type.ThreeParameterFunction import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.ToolConfig import com.google.ai.client.generativeai.type.TwoParameterFunction import com.google.ai.client.generativeai.type.content import kotlinx.coroutines.flow.Flow @@ -66,6 +67,7 @@ internal constructor( val generationConfig: GenerationConfig? = null, val safetySettings: List? = null, val tools: List? = null, + val toolConfig: ToolConfig? = null, val requestOptions: RequestOptions = RequestOptions(), private val controller: APIController, ) { @@ -77,6 +79,7 @@ internal constructor( generationConfig: GenerationConfig? = null, safetySettings: List? = null, tools: List? = null, + toolConfig: ToolConfig? = null, requestOptions: RequestOptions = RequestOptions(), ) : this( modelName, @@ -84,6 +87,7 @@ internal constructor( generationConfig, safetySettings, tools, + toolConfig, requestOptions, APIController(apiKey, modelName, requestOptions.toInternal()), ) @@ -223,6 +227,7 @@ internal constructor( safetySettings?.map { it.toInternal() }, generationConfig?.toInternal(), tools?.map { it.toInternal() }, + toolConfig?.toInternal(), ) private fun constructCountTokensRequest(vararg prompt: Content) = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index 45ae00cf..7e9f01d9 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -43,12 +43,14 @@ import com.google.ai.client.generativeai.common.shared.SafetySetting import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata +import com.google.ai.client.generativeai.type.FunctionCallingConfig import com.google.ai.client.generativeai.type.FunctionDeclaration import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.ImagePart import com.google.ai.client.generativeai.type.ParameterDeclaration import com.google.ai.client.generativeai.type.SerializationException import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.ToolConfig import com.google.ai.client.generativeai.type.content import java.io.ByteArrayOutputStream import kotlinx.serialization.decodeFromString @@ -114,6 +116,21 @@ internal fun BlockThreshold.toInternal() = internal fun Tool.toInternal() = com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() }) +@GenerativeBeta +internal fun ToolConfig.toInternal() = + com.google.ai.client.generativeai.common.client.ToolConfig( + com.google.ai.client.generativeai.common.client.FunctionCallingConfig( + when (functionCallingConfig.mode) { + FunctionCallingConfig.Mode.ANY -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.ANY + FunctionCallingConfig.Mode.AUTO -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.AUTO + FunctionCallingConfig.Mode.NONE -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.NONE + } + ) + ) + @GenerativeBeta internal fun FunctionDeclaration.toInternal() = com.google.ai.client.generativeai.common.client.FunctionDeclaration( diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt new file mode 100644 index 00000000..a2da761c --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt @@ -0,0 +1,43 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +/** + * Contains configuration for function calling from the model. This can be used to force function + * calling predictions or disable them. + * + * @param mode The function calling mode of the model + */ +@GenerativeBeta +class FunctionCallingConfig(val mode: Mode) { + enum class Mode { + /** + * The default behavior for function calling. The model calls functions to answer queries at its + * discretion + */ + AUTO, + + /** The model always predicts a provided function call to answer every query. */ + ANY, + + /** + * The model will never predict a function call to answer a query. This can also be achieved by + * not passing any tools to the model. + */ + NONE + } +} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt new file mode 100644 index 00000000..cc392320 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +/** + * Contains configuration for the function calling tools of the model. This can be used to change + * when the model can predict function calls. + * + * @param functionCallingConfig The config for function calling + */ +@OptIn(GenerativeBeta::class) +class ToolConfig(val functionCallingConfig: FunctionCallingConfig) { + + companion object { + /** Shorthand to construct a ToolConfig that restricts the model from calling any functions */ + fun never(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.NONE)) + /** Shorthand to construct a ToolConfig that restricts the model to always call some function */ + fun always(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY)) + } +}