Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function calling (common only) #97

Merged
merged 9 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.GenerationConfig
import com.google.ai.client.generativeai.common.client.Tool
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import kotlinx.serialization.SerialName
Expand All @@ -30,6 +31,7 @@ data class GenerateContentRequest(
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null,
) : Request

@Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ import kotlin.time.toDuration
class RequestOptions(
val timeout: Duration,
val apiVersion: String = "v1",
val endpoint: String = "https://generativelanguage.googleapis.com"
val endpoint: String = "https://generativelanguage.googleapis.com",
) {
@JvmOverloads
constructor(
timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
apiVersion: String = "v1",
endpoint: String = "https://generativelanguage.googleapis.com"
endpoint: String = "https://generativelanguage.googleapis.com",
) : this(
(timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
apiVersion,
endpoint
endpoint,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,25 @@ data class GenerationConfig(
@SerialName("top_k") val topK: Int?,
@SerialName("candidate_count") val candidateCount: Int?,
@SerialName("max_output_tokens") val maxOutputTokens: Int?,
@SerialName("stop_sequences") val stopSequences: List<String>?
@SerialName("stop_sequences") val stopSequences: List<String>?,
)

@Serializable data class Tool(val functionDeclarations: List<FunctionDeclaration>)

@Serializable
data class FunctionDeclaration(
val name: String,
val description: String,
val parameters: Schema,
)

@Serializable
data class Schema(
val type: String,
val description: String? = null,
val format: String? = null,
val enum: List<String>? = null,
val properties: Map<String, Schema>? = null,
val required: List<String>? = null,
val items: Schema? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.JsonContentPolymorphicSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonObject

object HarmCategorySerializer :
Expand All @@ -52,6 +53,14 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List<Pa

@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part

@Serializable data class FunctionCallPart(val functionCall: FunctionCall) : Part

@Serializable data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part

@Serializable data class FunctionResponse(val name: String, val response: JsonObject)

@Serializable data class FunctionCall(val name: String, val args: Map<String, String>)

@Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part

@Serializable
Expand Down Expand Up @@ -83,6 +92,8 @@ object PartSerializer : JsonContentPolymorphicSerializer<Part>(Part::class) {
val jsonObject = element.jsonObject
return when {
"text" in jsonObject -> TextPart.serializer()
"functionCall" in jsonObject -> FunctionCallPart.serializer()
"functionResponse" in jsonObject -> FunctionResponsePart.serializer()
"inline_data" in jsonObject -> BlobPart.serializer()
"file_data" in jsonObject -> FileDataPart.serializer()
else -> throw SerializationException("Unknown Part type")
Expand Down
Loading