diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt index faed8ee2..ebac13fd 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/APIController.kt @@ -240,7 +240,7 @@ private suspend fun validateResponse(response: HttpResponse) { if (message.contains("quota")) { throw QuotaExceededException(message) } - if (error.details.any { "SERVICE_DISABLED" == it.reason }) { + if (error.details?.any { "SERVICE_DISABLED" == it.reason } == true) { throw ServiceDisabledException(message) } throw ServerException(message) diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index cfc1e9d7..cb2fb478 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -30,6 +30,7 @@ data class GenerationConfig( @SerialName("response_mime_type") val responseMimeType: String? = null, @SerialName("presence_penalty") val presencePenalty: Float? = null, @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, + @SerialName("response_schema") val responseSchema: Schema? = null, ) @Serializable data class Tool(val functionDeclarations: List) diff --git a/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/GenerationConfig.kt b/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/GenerationConfig.kt index b9bb513a..bba0d3a1 100644 --- a/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/GenerationConfig.kt +++ b/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/GenerationConfig.kt @@ -39,6 +39,7 @@ private constructor( val maxOutputTokens: Int?, val stopSequences: List?, val responseMimeType: String?, + val responseSchema: Schema<*>?, ) { class Builder { @@ -49,6 +50,7 @@ private constructor( @JvmField var maxOutputTokens: Int? = null @JvmField var stopSequences: List? = null @JvmField var responseMimeType: String? = null + @JvmField var responseSchema: Schema<*>? = null fun build() = GenerationConfig( @@ -59,6 +61,7 @@ private constructor( maxOutputTokens = maxOutputTokens, stopSequences = stopSequences, responseMimeType = responseMimeType, + responseSchema = responseSchema, ) } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index 11eb1368..dadc2611 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -99,6 +99,7 @@ internal fun com.google.ai.client.generativeai.type.GenerationConfig.toInternal( maxOutputTokens = maxOutputTokens, stopSequences = stopSequences, responseMimeType = responseMimeType, + responseSchema = responseSchema?.toInternal(), ) internal fun com.google.ai.client.generativeai.type.HarmCategory.toInternal() = diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt new file mode 100644 index 00000000..80569ad7 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -0,0 +1,271 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +import org.json.JSONObject + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property function the function implementation + */ +class NoParameterFunction( + name: String, + description: String, + val function: suspend () -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf>() + + suspend fun execute() = function() + + override suspend fun execute(part: FunctionCallPart) = function() +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param A description of the first function parameter + * @property function the function implementation + */ +class OneParameterFunction( + name: String, + description: String, + val param: Schema, + val function: suspend (T) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param) + return function(arg1) + } +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property function the function implementation + */ +class TwoParameterFunction( + name: String, + description: String, + val param1: Schema, + val param2: Schema, + val function: suspend (T, U) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + return function(arg1, arg2) + } +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property param3 A description of the third function parameter + * @property function the function implementation + */ +class ThreeParameterFunction( + name: String, + description: String, + val param1: Schema, + val param2: Schema, + val param3: Schema, + val function: suspend (T, U, V) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2, param3) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + return function(arg1, arg2, arg3) + } +} + +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * @property name The name of the function call, this should be clear and descriptive for the model + * @property description A description of what the function does and its output. + * @property param1 A description of the first function parameter + * @property param2 A description of the second function parameter + * @property param3 A description of the third function parameter + * @property param4 A description of the fourth function parameter + * @property function the function implementation + */ +class FourParameterFunction( + name: String, + description: String, + val param1: Schema, + val param2: Schema, + val param3: Schema, + val param4: Schema, + val function: suspend (T, U, V, W) -> JSONObject, +) : FunctionDeclaration(name, description) { + override fun getParameters() = listOf(param1, param2, param3, param4) + + override suspend fun execute(part: FunctionCallPart): JSONObject { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + val arg4 = part.getArgOrThrow(param4) + return function(arg1, arg2, arg3, arg4) + } +} + +abstract class FunctionDeclaration(val name: String, val description: String) { + abstract fun getParameters(): List> + + abstract suspend fun execute(part: FunctionCallPart): JSONObject +} + +/** + * Represents a parameter for a declared function + * + * @property name: The name of the parameter + * @property description: The description of what the parameter should contain or represent + * @property format: format information for the parameter, this can include bitlength in the case of + * int/float or keywords like "enum" for the string type + * @property enum: contains the enum values for a string enum + * @property type: contains the type info and parser + * @property properties: if type is OBJECT, then this contains the description of the fields of the + * object by name + * @property required: if type is OBJECT, then this contains the list of required keys + * @property items: if the type is ARRAY, then this contains a description of the objects in the + * array + */ +class Schema( + val name: String, + val description: String, + val format: String? = null, + val enum: List? = null, + val properties: Map>? = null, + val required: List? = null, + val items: Schema? = null, + val type: FunctionType, +) { + fun fromString(value: String?) = type.parse(value) + + companion object { + /** Registers a schema for an integer number */ + fun int(name: String, description: String) = + Schema(name = name, description = description, type = FunctionType.INTEGER) + + /** Registers a schema for a string */ + fun str(name: String, description: String) = + Schema(name = name, description = description, type = FunctionType.STRING) + + /** Registers a schema for a boolean */ + fun bool(name: String, description: String) = + Schema(name = name, description = description, type = FunctionType.BOOLEAN) + + /** Registers a schema for a floating point number */ + fun num(name: String, description: String) = + Schema(name = name, description = description, type = FunctionType.NUMBER) + + /** + * Registers a schema for a complex object. In a function it will be returned as a [JSONObject] + */ + fun obj(name: String, description: String, vararg contents: Schema) = + Schema( + name = name, + description = description, + type = FunctionType.OBJECT, + required = contents.map { it.name }, + properties = contents.associateBy { it.name }.toMap(), + ) + + /** Registers a schema for an array */ + fun arr(name: String, description: String) = + Schema>(name = name, description = description, type = FunctionType.ARRAY) + + /** Registers a schema for an enum */ + fun enum(name: String, description: String, values: List) = + Schema( + name = name, + description = description, + format = "enum", + enum = values, + type = FunctionType.STRING, + ) + } +} + +fun defineFunction(name: String, description: String, function: suspend () -> JSONObject) = + NoParameterFunction(name, description, function) + +fun defineFunction( + name: String, + description: String, + arg1: Schema, + function: suspend (T) -> JSONObject, +) = OneParameterFunction(name, description, arg1, function) + +fun defineFunction( + name: String, + description: String, + arg1: Schema, + arg2: Schema, + function: suspend (T, U) -> JSONObject, +) = TwoParameterFunction(name, description, arg1, arg2, function) + +fun defineFunction( + name: String, + description: String, + arg1: Schema, + arg2: Schema, + arg3: Schema, + function: suspend (T, U, W) -> JSONObject, +) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) + +fun defineFunction( + name: String, + description: String, + arg1: Schema, + arg2: Schema, + arg3: Schema, + arg4: Schema, + function: suspend (T, U, W, Z) -> JSONObject, +) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) + +private fun FunctionCallPart.getArgOrThrow(param: Schema): T { + return param.fromString(args[param.name]) + ?: throw RuntimeException( + "Missing argument for parameter \"${param.name}\" for function \"$name\"" + ) +}