Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grounding with Google search feature #232

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ data class Tool(
val functionDeclarations: List<FunctionDeclaration>? = null,
// This is a json object because it is not possible to make a data class with no parameters.
val codeExecution: JsonObject? = null,
val googleSearchRetrieval: GoogleSearchRetrieval? = null,
)

@Serializable
Expand All @@ -60,6 +61,17 @@ data class FunctionCallingConfig(val mode: Mode) {
@Serializable
data class FunctionDeclaration(val name: String, val description: String, val parameters: Schema)

@Serializable data class GoogleSearchRetrieval(val dynamicRetrievalConfig: DynamicRetrievalConfig)

@Serializable
data class DynamicRetrievalConfig(val mode: Mode, val dynamicThreshold: Float?) {
@Serializable
enum class Mode {
@SerialName("MODE_DYNAMIC") DYNAMIC,
@SerialName("MODE_UNSPECIFIED") UNSPECIFIED
}
}

@Serializable
data class Schema(
val type: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import android.util.Base64
import com.google.ai.client.generativeai.common.CountTokensResponse
import com.google.ai.client.generativeai.common.GenerateContentResponse
import com.google.ai.client.generativeai.common.RequestOptions
import com.google.ai.client.generativeai.common.client.DynamicRetrievalConfig
import com.google.ai.client.generativeai.common.client.GenerationConfig
import com.google.ai.client.generativeai.common.client.GoogleSearchRetrieval
import com.google.ai.client.generativeai.common.client.Schema
import com.google.ai.client.generativeai.common.server.BlockReason
import com.google.ai.client.generativeai.common.server.Candidate
Expand Down Expand Up @@ -52,6 +54,7 @@ import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.type.BlockThreshold
import com.google.ai.client.generativeai.type.CitationMetadata
import com.google.ai.client.generativeai.type.DynamicRetrievalMode
import com.google.ai.client.generativeai.type.ExecutionOutcome
import com.google.ai.client.generativeai.type.FunctionCallingConfig
import com.google.ai.client.generativeai.type.FunctionDeclaration
Expand Down Expand Up @@ -144,6 +147,7 @@ internal fun Tool.toInternal() =
com.google.ai.client.generativeai.common.client.Tool(
functionDeclarations?.map { it.toInternal() },
codeExecution = codeExecution?.toInternal(),
googleSearchRetrieval = googleSearchRetrieval?.toInternal(),
)

internal fun ToolConfig.toInternal() =
Expand Down Expand Up @@ -189,6 +193,24 @@ internal fun <T> com.google.ai.client.generativeai.type.Schema<T>.toInternal():

internal fun JSONObject.toInternal() = Json.decodeFromString<JsonObject>(toString())

internal fun com.google.ai.client.generativeai.type.GoogleSearchRetrieval.toInternal():
GoogleSearchRetrieval {
return GoogleSearchRetrieval(dynamicRetrievalConfig.toInternal())
}

internal fun com.google.ai.client.generativeai.type.DynamicRetrievalConfig.toInternal():
DynamicRetrievalConfig {
return DynamicRetrievalConfig(
when (mode) {
DynamicRetrievalMode.DYNAMIC ->
com.google.ai.client.generativeai.common.client.DynamicRetrievalConfig.Mode.DYNAMIC
DynamicRetrievalMode.UNSPECIFIED ->
com.google.ai.client.generativeai.common.client.DynamicRetrievalConfig.Mode.UNSPECIFIED
},
dynamicThreshold,
)
}

internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candidate {
val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty()
val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

import androidx.annotation.FloatRange

/*
* Specifies the dynamic retrieval configuration for the given source.
*/
data class DynamicRetrievalConfig(
/*
* The mode of the predictor to be used in dynamic retrieval.
*/
val mode: DynamicRetrievalMode,
/*
* (Optional) The threshold to be used in dynamic retrieval. If not set, a system default value is used.
*/
@FloatRange(0.0, 1.0) val dynamicThreshold: Float? = null,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

/*
* The mode of the predictor to be used in [DynamicRetrievalConfig].
*/
enum class DynamicRetrievalMode {
DYNAMIC,
UNSPECIFIED
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

/** Retrieval tool that is powered by Google search. */
data class GoogleSearchRetrieval(val dynamicRetrievalConfig: DynamicRetrievalConfig)
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ import org.json.JSONObject
*
* @param functionDeclarations The set of functions that this tool allows the model access to
* @param codeExecution This is a flag value to enable Code Execution. Use [CODE_EXECUTION].
* @param googleSearchRetrieval This is a Retrieval tool that is powered by Google search.
*/
class Tool
@JvmOverloads
constructor(
val functionDeclarations: List<FunctionDeclaration>? = null,
val codeExecution: JSONObject? = null,
val googleSearchRetrieval: GoogleSearchRetrieval? = null,
) {
companion object {
@JvmField val CODE_EXECUTION = Tool(codeExecution = JSONObject())
Expand Down