diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt index 7499127571a..19abfa68d58 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt @@ -25,6 +25,7 @@ import com.google.firebase.vertexai.type.GenerationConfig import com.google.firebase.vertexai.type.RequestOptions import com.google.firebase.vertexai.type.SafetySetting import com.google.firebase.vertexai.type.Tool +import com.google.firebase.vertexai.type.ToolConfig /** Entry point for all Firebase Vertex AI functionality. */ class FirebaseVertexAI( @@ -36,8 +37,11 @@ class FirebaseVertexAI( * A facilitator for a given multimodal model (eg; Gemini). * * @param modelName name of the model in the backend + * @param generationConfig configuration parameters to use for content generation * @param safetySettings the safety bounds to use during alongside prompts during content * generation + * @param tools the list of tools to make available to the model + * @param toolConfig the configuration that defines how the model handles the tools provided * @param requestOptions configuration options to utilize during backend communication */ fun generativeModel( @@ -45,6 +49,7 @@ class FirebaseVertexAI( generationConfig: GenerationConfig? = null, safetySettings: List? = null, tools: List? = null, + toolConfig: ToolConfig? = null, requestOptions: RequestOptions = RequestOptions(), ) = GenerativeModel( @@ -53,6 +58,7 @@ class FirebaseVertexAI( generationConfig, safetySettings, tools, + toolConfig, requestOptions, appCheckProvider.get() ) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt index 8f122f5ffce..7a7390a0996 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt @@ -37,6 +37,7 @@ import com.google.firebase.vertexai.type.ResponseStoppedException import com.google.firebase.vertexai.type.SafetySetting import com.google.firebase.vertexai.type.SerializationException import com.google.firebase.vertexai.type.Tool +import com.google.firebase.vertexai.type.ToolConfig import com.google.firebase.vertexai.type.content import kotlin.time.Duration import kotlin.time.Duration.Companion.seconds @@ -63,6 +64,7 @@ internal constructor( val safetySettings: List? = null, val requestOptions: RequestOptions = RequestOptions(), val tools: List? = null, + val toolConfig: ToolConfig? = null, private val controller: APIController ) { @@ -73,6 +75,7 @@ internal constructor( generationConfig: GenerationConfig? = null, safetySettings: List? = null, tools: List? = null, + toolConfig: ToolConfig? = null, requestOptions: RequestOptions = RequestOptions(), appCheckTokenProvider: InteropAppCheckTokenProvider? = null ) : this( @@ -82,6 +85,7 @@ internal constructor( safetySettings, requestOptions, tools, + toolConfig, APIController( apiKey, modelName, @@ -212,6 +216,7 @@ internal constructor( safetySettings?.map { it.toInternal() }, generationConfig?.toInternal(), tools?.map { it.toInternal() }, + toolConfig?.toInternal() ) private fun constructCountTokensRequest(vararg prompt: Content) = diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt index d0ac383606d..0aa16e1156d 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt @@ -32,6 +32,7 @@ import com.google.firebase.vertexai.type.CitationMetadata import com.google.firebase.vertexai.type.Content import com.google.firebase.vertexai.type.CountTokensResponse import com.google.firebase.vertexai.type.FinishReason +import com.google.firebase.vertexai.type.FunctionCallingConfig import com.google.firebase.vertexai.type.FunctionDeclaration import com.google.firebase.vertexai.type.GenerateContentResponse import com.google.firebase.vertexai.type.GenerationConfig @@ -46,6 +47,7 @@ import com.google.firebase.vertexai.type.SafetySetting import com.google.firebase.vertexai.type.SerializationException import com.google.firebase.vertexai.type.TextPart import com.google.firebase.vertexai.type.Tool +import com.google.firebase.vertexai.type.ToolConfig import com.google.firebase.vertexai.type.UsageMetadata import com.google.firebase.vertexai.type.content import java.io.ByteArrayOutputStream @@ -121,6 +123,20 @@ internal fun com.google.firebase.vertexai.type.HarmCategory.toInternal() = HarmCategory.UNKNOWN -> com.google.ai.client.generativeai.common.shared.HarmCategory.UNKNOWN } +internal fun ToolConfig.toInternal() = + com.google.ai.client.generativeai.common.client.ToolConfig( + com.google.ai.client.generativeai.common.client.FunctionCallingConfig( + when (functionCallingConfig.mode) { + FunctionCallingConfig.Mode.ANY -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.ANY + FunctionCallingConfig.Mode.AUTO -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.AUTO + FunctionCallingConfig.Mode.NONE -> + com.google.ai.client.generativeai.common.client.FunctionCallingConfig.Mode.NONE + } + ) + ) + internal fun BlockThreshold.toInternal() = when (this) { BlockThreshold.NONE -> diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt new file mode 100644 index 00000000000..6e711718add --- /dev/null +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt @@ -0,0 +1,42 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai.type + +/** + * Contains configuration for function calling from the model. This can be used to force function + * calling predictions or disable them. + * + * @param mode The function calling mode of the model + */ +class FunctionCallingConfig(val mode: Mode) { + enum class Mode { + /** + * The default behavior for function calling. The model calls functions to answer queries at its + * discretion + */ + AUTO, + + /** The model always predicts a provided function call to answer every query. */ + ANY, + + /** + * The model will never predict a function call to answer a query. This can also be achieved by + * not passing any tools to the model. + */ + NONE + } +} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt new file mode 100644 index 00000000000..9575cb18844 --- /dev/null +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt @@ -0,0 +1,33 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai.type + +/** + * Contains configuration for the function calling tools of the model. This can be used to change + * when the model can predict function calls. + * + * @param functionCallingConfig The config for function calling + */ +class ToolConfig(val functionCallingConfig: FunctionCallingConfig) { + + companion object { + /** Shorthand to construct a ToolConfig that restricts the model from calling any functions */ + fun never(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.NONE)) + /** Shorthand to construct a ToolConfig that restricts the model to always call some function */ + fun always(): ToolConfig = ToolConfig(FunctionCallingConfig(FunctionCallingConfig.Mode.ANY)) + } +}