From bfe0ae8e9826b386fce5227f900ae2a3941348cc Mon Sep 17 00:00:00 2001
From: Emily Ploszaj <emilyploszaj@google.com>
Date: Wed, 27 Mar 2024 14:50:10 -0500
Subject: [PATCH] Add constrained decoding support

---
 .../ai/client/generativeai/common/Request.kt  |  2 +
 .../generativeai/common/client/Types.kt       | 15 +++++++
 .../ai/client/generativeai/GenerativeModel.kt |  5 +++
 .../generativeai/internal/util/conversions.kt | 17 ++++++++
 .../type/FunctionCallingConfig.kt             | 43 +++++++++++++++++++
 .../ai/client/generativeai/type/ToolConfig.kt | 34 +++++++++++++++
 6 files changed, 116 insertions(+)
 create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt
 create mode 100644 generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt

diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt
index 4e484266..f874c0b4 100644
--- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt
+++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/Request.kt
@@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.common
 
 import com.google.ai.client.generativeai.common.client.GenerationConfig
 import com.google.ai.client.generativeai.common.client.Tool
+import com.google.ai.client.generativeai.common.client.ToolConfig
 import com.google.ai.client.generativeai.common.shared.Content
 import com.google.ai.client.generativeai.common.shared.SafetySetting
 import kotlinx.serialization.SerialName
@@ -32,6 +33,7 @@ data class GenerateContentRequest(
   @SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
   @SerialName("generation_config") val generationConfig: GenerationConfig? = null,
   val tools: List<Tool>? = null,
+  @SerialName("tool_config") var toolConfig: ToolConfig? = null,
 ) : Request
 
 @Serializable
diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt
index 0a7335ce..d9be4d32 100644
--- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt
+++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt
@@ -31,6 +31,21 @@ data class GenerationConfig(
 
 @Serializable data class Tool(val functionDeclarations: List<FunctionDeclaration>)
 
+@Serializable
+data class ToolConfig(
+  @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig
+)
+
+@Serializable
+data class FunctionCallingConfig(val mode: Mode) {
+  @Serializable
+  enum class Mode {
+    @SerialName("MODE_UNSPECIFIED") UNSPECIFIED,
+    AUTO,
+    ANY,
+    NONE
+  }
+}
 @Serializable
 data class FunctionDeclaration(
   val name: String,
diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt
index 69b29efb..2a3ebeda 100644
--- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt
+++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt
@@ -40,6 +40,7 @@ import com.google.ai.client.generativeai.type.SafetySetting
 import com.google.ai.client.generativeai.type.SerializationException
 import com.google.ai.client.generativeai.type.ThreeParameterFunction
 import com.google.ai.client.generativeai.type.Tool
+import com.google.ai.client.generativeai.type.ToolConfig
 import com.google.ai.client.generativeai.type.TwoParameterFunction
 import com.google.ai.client.generativeai.type.content
 import kotlinx.coroutines.flow.Flow
@@ -66,6 +67,7 @@ internal constructor(
   val generationConfig: GenerationConfig? = null,
   val safetySettings: List<SafetySetting>? = null,
   val tools: List<Tool>? = null,
+  val toolConfig: ToolConfig? = null,
   val requestOptions: RequestOptions = RequestOptions(),
   private val controller: APIController,
 ) {
@@ -77,6 +79,7 @@ internal constructor(
     generationConfig: GenerationConfig? = null,
     safetySettings: List<SafetySetting>? = null,
     tools: List<Tool>? = null,
+    toolConfig: ToolConfig? = null,
     requestOptions: RequestOptions = RequestOptions(),
   ) : this(
     modelName,
@@ -84,6 +87,7 @@ internal constructor(
     generationConfig,
     safetySettings,
     tools,
+    toolConfig,
     requestOptions,
     APIController(apiKey, modelName, requestOptions.toInternal()),
   )
@@ -223,6 +227,7 @@ internal constructor(
       safetySettings?.map { it.toInternal() },
       generationConfig?.toInternal(),
       tools?.map { it.toInternal() },
+      toolConfig?.toInternal(),
     )
 
   private fun constructCountTokensRequest(vararg prompt: Content) =
diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt
index 45ae00cf..7e9f01d9 100644
--- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt
+++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt
@@ -43,12 +43,14 @@ import com.google.ai.client.generativeai.common.shared.SafetySetting
 import com.google.ai.client.generativeai.common.shared.TextPart
 import com.google.ai.client.generativeai.type.BlockThreshold
 import com.google.ai.client.generativeai.type.CitationMetadata
+import com.google.ai.client.generativeai.type.FunctionCallingConfig
 import com.google.ai.client.generativeai.type.FunctionDeclaration
 import com.google.ai.client.generativeai.type.GenerativeBeta
 import com.google.ai.client.generativeai.type.ImagePart
 import com.google.ai.client.generativeai.type.ParameterDeclaration
 import com.google.ai.client.generativeai.type.SerializationException
 import com.google.ai.client.generativeai.type.Tool
+import com.google.ai.client.generativeai.type.ToolConfig
 import com.google.ai.client.generativeai.type.content
 import java.io.ByteArrayOutputStream
 import kotlinx.serialization.decodeFromString
@@ -114,6 +116,21 @@ internal fun BlockThreshold.toInternal() =
 internal fun Tool.toInternal() =
   com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() })
 
+@GenerativeBeta
+internal fun ToolConfig.toInternal() =
+  com.google.ai.client.generativeai.common.client.ToolConfig(
+    com.google.ai.client.generativeai.common.client.FunctionCallingConfig(
+      when (functionCallingConfig.mode) {
+        FunctionCallingConfig.Mode.ANY ->
+          com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.ANY
+        FunctionCallingConfig.Mode.AUTO ->
+          com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.AUTO
+        FunctionCallingConfig.Mode.NONE ->
+          com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.NONE
+      }
+    )
+  )
+
 @GenerativeBeta
 internal fun FunctionDeclaration.toInternal() =
   com.google.ai.client.generativeai.common.client.FunctionDeclaration(
diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt
new file mode 100644
index 00000000..a2da761c
--- /dev/null
+++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionCallingConfig.kt
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.client.generativeai.type
+
+/**
+ * Contains configuration for function calling from the model. This can be used to force function
+ * calling predictions or disable them.
+ *
+ * @param mode The function calling mode of the model
+ */
+@GenerativeBeta
+class FunctionCallingConfig(val mode: Mode) {
+  enum class Mode {
+    /**
+     * The default behavior for function calling. The model calls functions to answer queries at its
+     * discretion
+     */
+    AUTO,
+
+    /** The model always predicts a provided function call to answer every query. */
+    ANY,
+
+    /**
+     * The model will never predict a function call to answer a query. This can also be achieved by
+     * not passing any tools to the model.
+     */
+    NONE
+  }
+}
diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt
new file mode 100644
index 00000000..cc392320
--- /dev/null
+++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/ToolConfig.kt
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.ai.client.generativeai.type
+
+/**
+ * Contains configuration for the function calling tools of the model. This can be used to change
+ * when the model can predict function calls.
+ *
+ * @param functionCallingConfig The config for function calling
+ */
+@OptIn(GenerativeBeta::class)
+class ToolConfig(val functionCallingConfig: FunctionCallingConfig) {
+
+  companion object {
+    /** Shorthand to construct a ToolConfig that restricts the model from calling any functions */
+    fun never(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.NONE))
+    /** Shorthand to construct a ToolConfig that restricts the model to always call some function */
+    fun always(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY))
+  }
+}