Skip to content

Commit

Permalink
Add support for constrain decoding (#5837)
Browse files Browse the repository at this point in the history
Based on the changes made in
google-gemini/generative-ai-android#103
  • Loading branch information
rlazo authored Apr 5, 2024
1 parent 5cad43f commit 687e98b
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.google.firebase.vertexai.type.GenerationConfig
import com.google.firebase.vertexai.type.RequestOptions
import com.google.firebase.vertexai.type.SafetySetting
import com.google.firebase.vertexai.type.Tool
import com.google.firebase.vertexai.type.ToolConfig

/** Entry point for all Firebase Vertex AI functionality. */
class FirebaseVertexAI(
Expand All @@ -36,15 +37,19 @@ class FirebaseVertexAI(
* A facilitator for a given multimodal model (eg; Gemini).
*
* @param modelName name of the model in the backend
* @param generationConfig configuration parameters to use for content generation
* @param safetySettings the safety bounds to use during alongside prompts during content
* generation
* @param tools the list of tools to make available to the model
* @param toolConfig the configuration that defines how the model handles the tools provided
* @param requestOptions configuration options to utilize during backend communication
*/
fun generativeModel(
modelName: String,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
tools: List<Tool>? = null,
toolConfig: ToolConfig? = null,
requestOptions: RequestOptions = RequestOptions(),
) =
GenerativeModel(
Expand All @@ -53,6 +58,7 @@ class FirebaseVertexAI(
generationConfig,
safetySettings,
tools,
toolConfig,
requestOptions,
appCheckProvider.get()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import com.google.firebase.vertexai.type.ResponseStoppedException
import com.google.firebase.vertexai.type.SafetySetting
import com.google.firebase.vertexai.type.SerializationException
import com.google.firebase.vertexai.type.Tool
import com.google.firebase.vertexai.type.ToolConfig
import com.google.firebase.vertexai.type.content
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
Expand All @@ -63,6 +64,7 @@ internal constructor(
val safetySettings: List<SafetySetting>? = null,
val requestOptions: RequestOptions = RequestOptions(),
val tools: List<Tool>? = null,
val toolConfig: ToolConfig? = null,
private val controller: APIController
) {

Expand All @@ -73,6 +75,7 @@ internal constructor(
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
tools: List<Tool>? = null,
toolConfig: ToolConfig? = null,
requestOptions: RequestOptions = RequestOptions(),
appCheckTokenProvider: InteropAppCheckTokenProvider? = null
) : this(
Expand All @@ -82,6 +85,7 @@ internal constructor(
safetySettings,
requestOptions,
tools,
toolConfig,
APIController(
apiKey,
modelName,
Expand Down Expand Up @@ -212,6 +216,7 @@ internal constructor(
safetySettings?.map { it.toInternal() },
generationConfig?.toInternal(),
tools?.map { it.toInternal() },
toolConfig?.toInternal()
)

private fun constructCountTokensRequest(vararg prompt: Content) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import com.google.firebase.vertexai.type.CitationMetadata
import com.google.firebase.vertexai.type.Content
import com.google.firebase.vertexai.type.CountTokensResponse
import com.google.firebase.vertexai.type.FinishReason
import com.google.firebase.vertexai.type.FunctionCallingConfig
import com.google.firebase.vertexai.type.FunctionDeclaration
import com.google.firebase.vertexai.type.GenerateContentResponse
import com.google.firebase.vertexai.type.GenerationConfig
Expand All @@ -46,6 +47,7 @@ import com.google.firebase.vertexai.type.SafetySetting
import com.google.firebase.vertexai.type.SerializationException
import com.google.firebase.vertexai.type.TextPart
import com.google.firebase.vertexai.type.Tool
import com.google.firebase.vertexai.type.ToolConfig
import com.google.firebase.vertexai.type.UsageMetadata
import com.google.firebase.vertexai.type.content
import java.io.ByteArrayOutputStream
Expand Down Expand Up @@ -121,6 +123,20 @@ internal fun com.google.firebase.vertexai.type.HarmCategory.toInternal() =
HarmCategory.UNKNOWN -> com.google.ai.client.generativeai.common.shared.HarmCategory.UNKNOWN
}

internal fun ToolConfig.toInternal() =
com.google.ai.client.generativeai.common.client.ToolConfig(
com.google.ai.client.generativeai.common.client.FunctionCallingConfig(
when (functionCallingConfig.mode) {
FunctionCallingConfig.Mode.ANY ->
com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.ANY
FunctionCallingConfig.Mode.AUTO ->
com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.AUTO
FunctionCallingConfig.Mode.NONE ->
com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.NONE
}
)
)

internal fun BlockThreshold.toInternal() =
when (this) {
BlockThreshold.NONE ->
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.vertexai.type

/**
* Contains configuration for function calling from the model. This can be used to force function
* calling predictions or disable them.
*
* @param mode The function calling mode of the model
*/
class FunctionCallingConfig(val mode: Mode) {
enum class Mode {
/**
* The default behavior for function calling. The model calls functions to answer queries at its
* discretion
*/
AUTO,

/** The model always predicts a provided function call to answer every query. */
ANY,

/**
* The model will never predict a function call to answer a query. This can also be achieved by
* not passing any tools to the model.
*/
NONE
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.vertexai.type

/**
* Contains configuration for the function calling tools of the model. This can be used to change
* when the model can predict function calls.
*
* @param functionCallingConfig The config for function calling
*/
class ToolConfig(val functionCallingConfig: FunctionCallingConfig) {

companion object {
/** Shorthand to construct a ToolConfig that restricts the model from calling any functions */
fun never(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.NONE))
/** Shorthand to construct a ToolConfig that restricts the model to always call some function */
fun always(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY))
}
}

0 comments on commit 687e98b

Please sign in to comment.