diff --git a/.changes/cloud-camp-bait-calculator.json b/.changes/cloud-camp-bait-calculator.json new file mode 100644 index 00000000..ddb946a0 --- /dev/null +++ b/.changes/cloud-camp-bait-calculator.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["Add function calling"]} diff --git a/generativeai/build.gradle.kts b/generativeai/build.gradle.kts index fbe489ea..ba88cefb 100644 --- a/generativeai/build.gradle.kts +++ b/generativeai/build.gradle.kts @@ -79,6 +79,7 @@ dependencies { implementation("org.slf4j:slf4j-nop:2.0.9") implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3") implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3") + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1") implementation("org.reactivestreams:reactive-streams:1.0.3") implementation("com.google.guava:listenablefuture:1.0") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt index b536c244..fbe8f346 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/Chat.kt @@ -167,8 +167,8 @@ class Chat(private val model: GenerativeModel, val history: MutableList } private fun Content.assertComesFromUser() { - if (role != "user") { - throw InvalidStateException("Chat prompts should come from the 'user' role.") + if (role !in listOf("user", "function")) { + throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.") } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 690f5be4..90dc8daa 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -25,18 +25,29 @@ import com.google.ai.client.generativeai.internal.util.toPublic import com.google.ai.client.generativeai.type.Content import com.google.ai.client.generativeai.type.CountTokensResponse import com.google.ai.client.generativeai.type.FinishReason +import com.google.ai.client.generativeai.type.FourParameterFunction +import com.google.ai.client.generativeai.type.FunctionCallPart import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.GenerationConfig +import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.GoogleGenerativeAIException +import com.google.ai.client.generativeai.type.InvalidStateException +import com.google.ai.client.generativeai.type.NoParameterFunction +import com.google.ai.client.generativeai.type.OneParameterFunction import com.google.ai.client.generativeai.type.PromptBlockedException import com.google.ai.client.generativeai.type.RequestOptions import com.google.ai.client.generativeai.type.ResponseStoppedException import com.google.ai.client.generativeai.type.SafetySetting import com.google.ai.client.generativeai.type.SerializationException +import com.google.ai.client.generativeai.type.ThreeParameterFunction +import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.TwoParameterFunction import com.google.ai.client.generativeai.type.content import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.map +import kotlinx.serialization.ExperimentalSerializationApi +import org.json.JSONObject /** * A facilitator for a given multimodal model (eg; Gemini). @@ -48,14 +59,16 @@ import kotlinx.coroutines.flow.map * generation * @property requestOptions configuration options to utilize during backend communication */ +@OptIn(ExperimentalSerializationApi::class) class GenerativeModel internal constructor( val modelName: String, val apiKey: String, val generationConfig: GenerationConfig? = null, val safetySettings: List? = null, + val tools: List? = null, val requestOptions: RequestOptions = RequestOptions(), - private val controller: APIController + private val controller: APIController, ) { @JvmOverloads @@ -64,14 +77,16 @@ internal constructor( apiKey: String, generationConfig: GenerationConfig? = null, safetySettings: List? = null, + tools: List? = null, requestOptions: RequestOptions = RequestOptions(), ) : this( modelName, apiKey, generationConfig, safetySettings, + tools, requestOptions, - APIController(apiKey, modelName, requestOptions.toInternal()) + APIController(apiKey, modelName, requestOptions.toInternal()), ) /** @@ -171,12 +186,45 @@ internal constructor( return countTokens(content { image(prompt) }) } + /** + * Executes a function requested by the model. + * + * @param functionCallPart A [FunctionCallPart] from the model, containing a function call and + * parameters + * @return The output of the requested function call + */ + @OptIn(GenerativeBeta::class) + suspend fun executeFunction(functionCallPart: FunctionCallPart): JSONObject { + if (tools == null) { + throw InvalidStateException("No registered tools") + } + val callable = + tools.flatMap { it.functionDeclarations }.firstOrNull { it.name == functionCallPart.name } + ?: throw InvalidStateException("No registered function named ${functionCallPart.name}") + return when (callable) { + is NoParameterFunction -> callable.execute() + is OneParameterFunction<*> -> + (callable as OneParameterFunction).execute(functionCallPart) + is TwoParameterFunction<*, *> -> + (callable as TwoParameterFunction).execute(functionCallPart) + is ThreeParameterFunction<*, *, *> -> + (callable as ThreeParameterFunction).execute(functionCallPart) + is FourParameterFunction<*, *, *, *> -> + (callable as FourParameterFunction).execute(functionCallPart) + else -> { + throw RuntimeException("UNREACHABLE") + } + } + } + + @OptIn(GenerativeBeta::class) private fun constructRequest(vararg prompt: Content) = GenerateContentRequest( modelName, prompt.map { it.toInternal() }, safetySettings?.map { it.toInternal() }, - generationConfig?.toInternal() + generationConfig?.toInternal(), + tools?.map { it.toInternal() }, ) private fun constructCountTokensRequest(vararg prompt: Content) = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index 5abf8509..c1b813e5 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -23,6 +23,7 @@ import com.google.ai.client.generativeai.common.CountTokensResponse import com.google.ai.client.generativeai.common.GenerateContentResponse import com.google.ai.client.generativeai.common.RequestOptions import com.google.ai.client.generativeai.common.client.GenerationConfig +import com.google.ai.client.generativeai.common.client.Schema import com.google.ai.client.generativeai.common.server.BlockReason import com.google.ai.client.generativeai.common.server.Candidate import com.google.ai.client.generativeai.common.server.CitationSources @@ -33,6 +34,10 @@ import com.google.ai.client.generativeai.common.server.SafetyRating import com.google.ai.client.generativeai.common.shared.Blob import com.google.ai.client.generativeai.common.shared.BlobPart import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.FunctionCall +import com.google.ai.client.generativeai.common.shared.FunctionCallPart +import com.google.ai.client.generativeai.common.shared.FunctionResponse +import com.google.ai.client.generativeai.common.shared.FunctionResponsePart import com.google.ai.client.generativeai.common.shared.HarmBlockThreshold import com.google.ai.client.generativeai.common.shared.HarmCategory import com.google.ai.client.generativeai.common.shared.Part @@ -40,10 +45,16 @@ import com.google.ai.client.generativeai.common.shared.SafetySetting import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata +import com.google.ai.client.generativeai.type.FunctionDeclaration +import com.google.ai.client.generativeai.type.GenerativeBeta import com.google.ai.client.generativeai.type.ImagePart import com.google.ai.client.generativeai.type.SerializationException +import com.google.ai.client.generativeai.type.Tool import com.google.ai.client.generativeai.type.content import java.io.ByteArrayOutputStream +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import org.json.JSONObject private const val BASE_64_FLAGS = Base64.NO_WRAP @@ -59,6 +70,10 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part { is ImagePart -> BlobPart(Blob("image/jpeg", encodeBitmapToBase64Png(image))) is com.google.ai.client.generativeai.type.BlobPart -> BlobPart(Blob(mimeType, Base64.encodeToString(blob, BASE_64_FLAGS))) + is com.google.ai.client.generativeai.type.FunctionCallPart -> + FunctionCallPart(FunctionCall(name, args.orEmpty())) + is com.google.ai.client.generativeai.type.FunctionResponsePart -> + FunctionResponsePart(FunctionResponse(name, response.toInternal())) else -> throw SerializationException( "The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet." @@ -76,7 +91,7 @@ internal fun com.google.ai.client.generativeai.type.GenerationConfig.toInternal( topK = topK, candidateCount = candidateCount, maxOutputTokens = maxOutputTokens, - stopSequences = stopSequences + stopSequences = stopSequences, ) internal fun com.google.ai.client.generativeai.type.HarmCategory.toInternal() = @@ -99,6 +114,35 @@ internal fun BlockThreshold.toInternal() = BlockThreshold.UNSPECIFIED -> HarmBlockThreshold.UNSPECIFIED } +@GenerativeBeta +internal fun Tool.toInternal() = + com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() }) + +@GenerativeBeta +internal fun FunctionDeclaration.toInternal() = + com.google.ai.client.generativeai.common.client.FunctionDeclaration( + name, + description, + Schema( + properties = getParameters().associate { it.name to it.toInternal() }, + required = getParameters().map { it.name }, + type = "OBJECT", + ), + ) + +internal fun com.google.ai.client.generativeai.type.Schema.toInternal(): Schema = + Schema( + type.name, + description, + format, + enum, + properties?.mapValues { it.value.toInternal() }, + required, + items?.toInternal(), + ) + +internal fun JSONObject.toInternal() = Json.decodeFromString(toString()) + internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candidate { val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() val citations = citationMetadata?.citationSources?.map { it.toPublic() }.orEmpty() @@ -108,7 +152,7 @@ internal fun Candidate.toPublic(): com.google.ai.client.generativeai.type.Candid this.content?.toPublic() ?: content("model") {}, safetyRatings, citations, - finishReason + finishReason, ) } @@ -126,6 +170,16 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part { com.google.ai.client.generativeai.type.BlobPart(inlineData.mimeType, data) } } + is FunctionCallPart -> + com.google.ai.client.generativeai.type.FunctionCallPart( + functionCall.name, + functionCall.args.orEmpty(), + ) + is FunctionResponsePart -> + com.google.ai.client.generativeai.type.FunctionResponsePart( + functionResponse.name, + functionResponse.response.toPublic(), + ) else -> throw SerializationException( "Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK." @@ -192,12 +246,14 @@ internal fun BlockReason.toPublic() = internal fun GenerateContentResponse.toPublic() = com.google.ai.client.generativeai.type.GenerateContentResponse( candidates?.map { it.toPublic() }.orEmpty(), - promptFeedback?.toPublic() + promptFeedback?.toPublic(), ) internal fun CountTokensResponse.toPublic() = com.google.ai.client.generativeai.type.CountTokensResponse(totalTokens) +internal fun JsonObject.toPublic() = JSONObject(toString()) + private fun encodeBitmapToBase64Png(input: Bitmap): String { ByteArrayOutputStream().let { input.compress(Bitmap.CompressFormat.JPEG, 80, it) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt new file mode 100644 index 00000000..0382a411 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -0,0 +1,276 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +import org.json.JSONObject + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property function the function implementation + */ +@GenerativeBeta +class NoParameterFunction( + name: String, + description: String, + val function: suspend () -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf>() + + suspend fun execute() = function() + + override suspend fun execute(part: FunctionCallPart) = function() +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param A description of the first function parameter + * @property function the function implementation + */ +@GenerativeBeta +class OneParameterFunction( + name: String, + description: String, + val param: Schema, + val function: suspend (T) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param) + return function(arg1) + } +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property function the function implementation + */ +@GenerativeBeta +class TwoParameterFunction( + name: String, + description: String, + val param1: Schema, + val param2: Schema, + val function: suspend (T, U) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + return function(arg1, arg2) + } +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property param3 A description of the third function parameter + * @property function the function implementation + */ +@GenerativeBeta +class ThreeParameterFunction( + name: String, + description: String, + val param1: Schema, + val param2: Schema, + val param3: Schema, + val function: suspend (T, U, V) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2, param3) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + return function(arg1, arg2, arg3) + } +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property param3 A description of the third function parameter + * @property param4 A description of the fourth function parameter + * @property function the function implementation + */ +@GenerativeBeta +class FourParameterFunction( + name: String, + description: String, + val param1: Schema, + val param2: Schema, + val param3: Schema, + val param4: Schema, + val function: suspend (T, U, V, W) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2, param3, param4) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + val arg4 = part.getArgOrThrow(param4) + return function(arg1, arg2, arg3, arg4) + } +} + +@GenerativeBeta +abstract class FunctionDeclaration(val name: String, val description: String) { + abstract fun getParameters(): List> + + abstract suspend fun execute(part: FunctionCallPart): JSONObject +} + +/** + * Represents a parameter for a declared function + * + * @property name: The name of the parameter + * @property description: The description of what the parameter should contain or represent + * @property format: format information for the parameter, this can include bitlength in the case of + * int/float or keywords like "enum" for the string type + * @property enum: contains the enum values for a string enum + * @property type: contains the type info and parser + * @property properties: if type is OBJECT, then this contains the description of the fields of the + * object by name + * @property required: if type is OBJECT, then this contains the list of required keys + * @property items: if the type is ARRAY, then this contains a description of the objects in the + * array + */ +class Schema( + val name: String, + val description: String, + val format: String? = null, + val enum: List? = null, + val properties: Map>? = null, + val required: List? = null, + val items: Schema? = null, + val type: FunctionType, +) { + fun fromString(value: String?) = type.parse(value) + + companion object { + /** Registers a schema for an integer number */ + fun int(name: String, description: String) = + Schema(name = name, description = description, type = FunctionType.INTEGER) + + /** Registers a schema for a string */ + fun str(name: String, description: String) = + Schema(name = name, description = description, type = FunctionType.STRING) + + /** Registers a schema for a boolean */ + fun bool(name: String, description: String) = + Schema(name = name, description = description, type = FunctionType.BOOLEAN) + + /** Registers a schema for a floating point number */ + fun num(name: String, description: String) = + Schema(name = name, description = description, type = FunctionType.NUMBER) + + /** + * Registers a schema for a complex object. In a function it will be returned as a [JSONObject] + */ + fun obj(name: String, description: String) = + Schema(name = name, description = description, type = FunctionType.OBJECT) + + /** Registers a schema for an array */ + fun arr(name: String, description: String) = + Schema>(name = name, description = description, type = FunctionType.ARRAY) + + /** Registers a schema for an enum */ + fun enum(name: String, description: String, values: List) = + Schema( + name = name, + description = description, + format = "enum", + enum = values, + type = FunctionType.STRING, + ) + } +} + +@GenerativeBeta +fun defineFunction(name: String, description: String, function: suspend () -> JSONObject) = + NoParameterFunction(name, description, function) + +@GenerativeBeta +fun defineFunction( + name: String, + description: String, + arg1: Schema, + function: suspend (T) -> JSONObject, +) = OneParameterFunction(name, description, arg1, function) + +@GenerativeBeta +fun defineFunction( + name: String, + description: String, + arg1: Schema, + arg2: Schema, + function: suspend (T, U) -> JSONObject, +) = TwoParameterFunction(name, description, arg1, arg2, function) + +@GenerativeBeta +fun defineFunction( + name: String, + description: String, + arg1: Schema, + arg2: Schema, + arg3: Schema, + function: suspend (T, U, W) -> JSONObject, +) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) + +@GenerativeBeta +fun defineFunction( + name: String, + description: String, + arg1: Schema, + arg2: Schema, + arg3: Schema, + arg4: Schema, + function: suspend (T, U, W, Z) -> JSONObject, +) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) + +private fun FunctionCallPart.getArgOrThrow(param: Schema): T { + return param.fromString(args[param.name]) + ?: throw RuntimeException( + "Missing argument for parameter \"${param.name}\" for function \"$name\"" + ) +} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt new file mode 100644 index 00000000..cb9ccbe1 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionParameter.kt @@ -0,0 +1,19 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +class FunctionParameter(val name: String, val description: String, val type: FunctionType) {} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt index 9e33ce16..9fe30329 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt @@ -32,6 +32,12 @@ class GenerateContentResponse( /** Convenience field representing the first text part in the response, if it exists. */ val text: String? by lazy { firstPartAs()?.text } + /** Convenience field representing the first text part in the response, if it exists. */ + val functionCall: FunctionCallPart? by lazy { firstPartAs() } + + /** Convenience field representing the first text part in the response, if it exists. */ + val functionResponse: FunctionResponsePart? by lazy { firstPartAs() } + private inline fun firstPartAs(): T? { if (candidates.isEmpty()) { warn("No candidates were found, but was asked to get a candidate.") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt new file mode 100644 index 00000000..509f62e3 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerativeBeta.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +@RequiresOptIn(message = "This API is only available on the v1beta version") +@Retention(AnnotationRetention.BINARY) +@Target(AnnotationTarget.CLASS, AnnotationTarget.FUNCTION) +annotation class GenerativeBeta diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt index 4c54cc43..83f22953 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt @@ -17,6 +17,7 @@ package com.google.ai.client.generativeai.type import android.graphics.Bitmap +import org.json.JSONObject /** * Interface representing data sent to and received from requests. @@ -40,6 +41,12 @@ class ImagePart(val image: Bitmap) : Part /** Represents binary data with an associated MIME type sent to and received from requests. */ class BlobPart(val mimeType: String, val blob: ByteArray) : Part +/** Represents function call name and params received from requests. */ +class FunctionCallPart(val name: String, val args: Map) : Part + +/** Represents function call output to be returned to the model when it requests a function call */ +class FunctionResponsePart(val name: String, val response: JSONObject) : Part + /** @return The part as a [String] if it represents text, and null otherwise */ fun Part.asTextOrNull(): String? = (this as? TextPart)?.text diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt index e25561fb..a2470ef2 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/RequestOptions.kt @@ -27,10 +27,16 @@ import kotlin.time.toDuration * first response. * @property apiVersion the api endpoint to call. */ -class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") { +class RequestOptions( + val timeout: Duration, + val apiVersion: String = "v1", +) { @JvmOverloads constructor( timeout: Long? = Long.MAX_VALUE, - apiVersion: String = "v1" - ) : this((timeout ?: Long.MAX_VALUE).toDuration(DurationUnit.MILLISECONDS), apiVersion) + apiVersion: String = "v1", + ) : this( + (timeout ?: Long.MAX_VALUE).toDuration(DurationUnit.MILLISECONDS), + apiVersion, + ) } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt new file mode 100644 index 00000000..67467710 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt @@ -0,0 +1,28 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +/** + * Contains a set of function declarations that the model has access to. These can be used to gather + * information, or complete tasks + * + * @param functionDeclarations The set of functions that this tool allows the model access to + */ +@OptIn(GenerativeBeta::class) +class Tool( + val functionDeclarations: List, +) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt new file mode 100644 index 00000000..c220031d --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Type.kt @@ -0,0 +1,42 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.jsonArray +import org.json.JSONObject + +/** + * Represents and passes the type information for an automated function call. + * + * @property name: the enum name of the type + * @property parse: the deserialization function + * @property T: the type of the object that this maps to in code. + */ +class FunctionType(val name: String, val parse: (String?) -> T?) { + companion object { + val STRING = FunctionType("STRING") { it } + val INTEGER = FunctionType("INTEGER") { it?.toLongOrNull() } + val NUMBER = FunctionType("NUMBER") { it?.toDoubleOrNull() } + val BOOLEAN = FunctionType("BOOLEAN") { it?.toBoolean() } + val ARRAY = + FunctionType>("ARRAY") { it -> + it?.let { Json.parseToJsonElement(it).jsonArray.map { element -> element.toString() } } + } + val OBJECT = FunctionType("OBJECT") { it?.let { JSONObject(it) } } + } +}