Skip to content

Commit

Permalink
Add constrained decoding support (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
emilypgoogle authored Apr 3, 2024
1 parent f22a52f commit f14808f
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.GenerationConfig
import com.google.ai.client.generativeai.common.client.Tool
import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import kotlinx.serialization.SerialName
Expand All @@ -33,6 +34,7 @@ data class GenerateContentRequest(
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null,
@SerialName("tool_config") var toolConfig: ToolConfig? = null,
) : Request

@Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ data class GenerationConfig(

@Serializable data class Tool(val functionDeclarations: List<FunctionDeclaration>)

@Serializable
data class ToolConfig(
@SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig
)

@Serializable
data class FunctionCallingConfig(val mode: Mode) {
@Serializable
enum class Mode {
@SerialName("MODE_UNSPECIFIED") UNSPECIFIED,
AUTO,
ANY,
NONE
}
}

@Serializable
data class FunctionDeclaration(
val name: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.FunctionCallingConfig
import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.common.util.commonTest
Expand Down Expand Up @@ -157,6 +159,38 @@ internal class RequestFormatTests {
requestBodyAsText shouldContainJsonKey "contents"
requestBodyAsText shouldNotContainJsonKey "model"
}

@Test
fun `ToolConfig serialization contains correct keys`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }

val controller =
APIController("super_cool_test_key", "gemini-pro-1.0", RequestOptions(), mockEngine)

withTimeout(5.seconds) {
controller
.generateContentStream(
GenerateContentRequest(
model = "unused",
contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))),
toolConfig =
ToolConfig(
functionCallingConfig =
FunctionCallingConfig(mode = FunctionCallingConfig.Mode.AUTO)
)
)
)
.collect { channel.close() }
}

val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text

requestBodyAsText shouldContainJsonKey "tool_config.function_calling_config.mode"
}
}

@RunWith(Parameterized::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.ThreeParameterFunction
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.ToolConfig
import com.google.ai.client.generativeai.type.TwoParameterFunction
import com.google.ai.client.generativeai.type.content
import kotlinx.coroutines.flow.Flow
Expand All @@ -67,6 +68,7 @@ internal constructor(
val generationConfig: GenerationConfig? = null,
val safetySettings: List<SafetySetting>? = null,
val tools: List<Tool>? = null,
val toolConfig: ToolConfig? = null,
val requestOptions: RequestOptions = RequestOptions(),
private val controller: APIController,
) {
Expand All @@ -78,13 +80,15 @@ internal constructor(
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
tools: List<Tool>? = null,
toolConfig: ToolConfig? = null,
requestOptions: RequestOptions = RequestOptions(),
) : this(
modelName,
apiKey,
generationConfig,
safetySettings,
tools,
toolConfig,
requestOptions,
APIController(apiKey, modelName, requestOptions.toInternal()),
)
Expand Down Expand Up @@ -225,6 +229,7 @@ internal constructor(
safetySettings?.map { it.toInternal() },
generationConfig?.toInternal(),
tools?.map { it.toInternal() },
toolConfig?.toInternal(),
)

private fun constructCountTokensRequest(vararg prompt: Content) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.type.BlockThreshold
import com.google.ai.client.generativeai.type.CitationMetadata
import com.google.ai.client.generativeai.type.FunctionCallingConfig
import com.google.ai.client.generativeai.type.FunctionDeclaration
import com.google.ai.client.generativeai.type.GenerativeBeta
import com.google.ai.client.generativeai.type.ImagePart
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.ToolConfig
import com.google.ai.client.generativeai.type.content
import java.io.ByteArrayOutputStream
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -118,6 +120,21 @@ internal fun BlockThreshold.toInternal() =
internal fun Tool.toInternal() =
com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() })

@GenerativeBeta
internal fun ToolConfig.toInternal() =
com.google.ai.client.generativeai.common.client.ToolConfig(
com.google.ai.client.generativeai.common.client.FunctionCallingConfig(
when (functionCallingConfig.mode) {
FunctionCallingConfig.Mode.ANY ->
com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.ANY
FunctionCallingConfig.Mode.AUTO ->
com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.AUTO
FunctionCallingConfig.Mode.NONE ->
com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.NONE
}
)
)

@GenerativeBeta
internal fun FunctionDeclaration.toInternal() =
com.google.ai.client.generativeai.common.client.FunctionDeclaration(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

/**
* Contains configuration for function calling from the model. This can be used to force function
* calling predictions or disable them.
*
* @param mode The function calling mode of the model
*/
@GenerativeBeta
class FunctionCallingConfig(val mode: Mode) {
enum class Mode {
/**
* The default behavior for function calling. The model calls functions to answer queries at its
* discretion
*/
AUTO,

/** The model always predicts a provided function call to answer every query. */
ANY,

/**
* The model will never predict a function call to answer a query. This can also be achieved by
* not passing any tools to the model.
*/
NONE
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

/**
* Contains configuration for the function calling tools of the model. This can be used to change
* when the model can predict function calls.
*
* @param functionCallingConfig The config for function calling
*/
@OptIn(GenerativeBeta::class)
class ToolConfig(val functionCallingConfig: FunctionCallingConfig) {

companion object {
/** Shorthand to construct a ToolConfig that restricts the model from calling any functions */
fun never(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.NONE))
/** Shorthand to construct a ToolConfig that restricts the model to always call some function */
fun always(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY))
}
}

0 comments on commit f14808f

Please sign in to comment.