diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index cb2fb478..c8fc3fb6 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -59,6 +59,7 @@ data class Schema( val type: String, val description: String? = null, val format: String? = null, + val nullable: Boolean? = false, val enum: List? = null, val properties: Map? = null, val required: List? = null, diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt index 29882560..64e816e2 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -59,7 +59,7 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List) +@Serializable data class FunctionCall(val name: String, val args: Map) @Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt index a3f8d77f..13ec5428 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt @@ -20,6 +20,7 @@ import com.google.ai.client.generativeai.common.server.BlockReason import com.google.ai.client.generativeai.common.server.FinishReason import com.google.ai.client.generativeai.common.server.HarmProbability import com.google.ai.client.generativeai.common.server.HarmSeverity +import com.google.ai.client.generativeai.common.shared.FunctionCallPart import com.google.ai.client.generativeai.common.shared.HarmCategory import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.common.util.goldenUnaryFile @@ -301,4 +302,15 @@ internal class UnarySnapshotTests { } } } + + @Test + fun `function call contains null param`() = + goldenUnaryFile("success-function-call-null.json") { + withTimeout(testTimeout) { + val response = apiController.generateContent(textGenerateContentRequest("prompt")) + val callPart = (response.candidates!!.first().content!!.parts.first() as FunctionCallPart) + + callPart.functionCall.args["season"] shouldBe null + } + } } diff --git a/common/src/test/resources/golden-files/unary/success-function-call-null.json b/common/src/test/resources/golden-files/unary/success-function-call-null.json new file mode 100644 index 00000000..14801eef --- /dev/null +++ b/common/src/test/resources/golden-files/unary/success-function-call-null.json @@ -0,0 +1,45 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "functionName", + "args": { + "original_title": "String", + "season": null + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 774, + "candidatesTokenCount": 4176, + "totalTokenCount": 4950 + } +} diff --git a/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/Part.kt b/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/Part.kt index 6dd316e9..ea65e4eb 100644 --- a/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/Part.kt +++ b/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/Part.kt @@ -47,7 +47,7 @@ class FileDataPart(val uri: String, val mimeType: String) : Part fun Part.asFileDataPartOrNull(): FileDataPart? = this as? FileDataPart /** Represents function call name and params received from requests. */ -class FunctionCallPart(val name: String, val args: Map) : Part +class FunctionCallPart(val name: String, val args: Map) : Part /** Represents function call output to be returned to the model when it requests a function call */ class FunctionResponsePart(val name: String, val response: JsonObject) : Part diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index dadc2611..dfa8c3c0 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -150,6 +150,7 @@ internal fun FunctionDeclaration.toInternal() = properties = getParameters().associate { it.name to it.toInternal() }, required = getParameters().map { it.name }, type = "OBJECT", + nullable = false, ), ) @@ -158,6 +159,7 @@ internal fun com.google.ai.client.generativeai.type.Schema.toInternal(): type.name, description, format, + nullable, enum, properties?.mapValues { it.value.toInternal() }, required, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 80569ad7..21d02b08 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -173,6 +173,7 @@ class Schema( val name: String, val description: String, val format: String? = null, + val nullable: Boolean? = null, val enum: List? = null, val properties: Map>? = null, val required: List? = null, @@ -184,19 +185,39 @@ class Schema( companion object { /** Registers a schema for an integer number */ fun int(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.INTEGER) + Schema( + name = name, + description = description, + type = FunctionType.INTEGER, + nullable = false, + ) /** Registers a schema for a string */ fun str(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.STRING) + Schema( + name = name, + description = description, + type = FunctionType.STRING, + nullable = false, + ) /** Registers a schema for a boolean */ fun bool(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.BOOLEAN) + Schema( + name = name, + description = description, + type = FunctionType.BOOLEAN, + nullable = false, + ) /** Registers a schema for a floating point number */ fun num(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.NUMBER) + Schema( + name = name, + description = description, + type = FunctionType.NUMBER, + nullable = false, + ) /** * Registers a schema for a complex object. In a function it will be returned as a [JSONObject] @@ -208,11 +229,17 @@ class Schema( type = FunctionType.OBJECT, required = contents.map { it.name }, properties = contents.associateBy { it.name }.toMap(), + nullable = false, ) /** Registers a schema for an array */ fun arr(name: String, description: String) = - Schema>(name = name, description = description, type = FunctionType.ARRAY) + Schema>( + name = name, + description = description, + type = FunctionType.ARRAY, + nullable = false, + ) /** Registers a schema for an enum */ fun enum(name: String, description: String, values: List) = @@ -222,6 +249,7 @@ class Schema( format = "enum", enum = values, type = FunctionType.STRING, + nullable = false, ) } }