Skip to content

Commit

Permalink
Function calling (common only) (#97)
Browse files Browse the repository at this point in the history
Based on work from
https://github.com/google/generative-ai-android/tree/davidmotson.auto_function_split

---------

Co-authored-by: David Motsonashvili <[email protected]>
Co-authored-by: David Motsonashvili <[email protected]>
  • Loading branch information
3 people authored Mar 26, 2024
1 parent 849f432 commit a2b71df
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.GenerationConfig
import com.google.ai.client.generativeai.common.client.Tool
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.SafetySetting
import kotlinx.serialization.SerialName
Expand All @@ -30,6 +31,7 @@ data class GenerateContentRequest(
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null,
) : Request

@Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ import kotlin.time.toDuration
class RequestOptions(
val timeout: Duration,
val apiVersion: String = "v1",
val endpoint: String = "https://generativelanguage.googleapis.com"
val endpoint: String = "https://generativelanguage.googleapis.com",
) {
@JvmOverloads
constructor(
timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
apiVersion: String = "v1",
endpoint: String = "https://generativelanguage.googleapis.com"
endpoint: String = "https://generativelanguage.googleapis.com",
) : this(
(timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
apiVersion,
endpoint
endpoint,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,25 @@ data class GenerationConfig(
@SerialName("top_k") val topK: Int?,
@SerialName("candidate_count") val candidateCount: Int?,
@SerialName("max_output_tokens") val maxOutputTokens: Int?,
@SerialName("stop_sequences") val stopSequences: List<String>?
@SerialName("stop_sequences") val stopSequences: List<String>?,
)

@Serializable data class Tool(val functionDeclarations: List<FunctionDeclaration>)

@Serializable
data class FunctionDeclaration(
val name: String,
val description: String,
val parameters: Schema,
)

@Serializable
data class Schema(
val type: String,
val description: String? = null,
val format: String? = null,
val enum: List<String>? = null,
val properties: Map<String, Schema>? = null,
val required: List<String>? = null,
val items: Schema? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.JsonContentPolymorphicSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonObject

object HarmCategorySerializer :
Expand All @@ -52,6 +53,14 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List<Pa

@Serializable data class BlobPart(@SerialName("inline_data") val inlineData: Blob) : Part

@Serializable data class FunctionCallPart(val functionCall: FunctionCall) : Part

@Serializable data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part

@Serializable data class FunctionResponse(val name: String, val response: JsonObject)

@Serializable data class FunctionCall(val name: String, val args: Map<String, String>)

@Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part

@Serializable
Expand Down Expand Up @@ -83,6 +92,8 @@ object PartSerializer : JsonContentPolymorphicSerializer<Part>(Part::class) {
val jsonObject = element.jsonObject
return when {
"text" in jsonObject -> TextPart.serializer()
"functionCall" in jsonObject -> FunctionCallPart.serializer()
"functionResponse" in jsonObject -> FunctionResponsePart.serializer()
"inline_data" in jsonObject -> BlobPart.serializer()
"file_data" in jsonObject -> FileDataPart.serializer()
else -> throw SerializationException("Unknown Part type")
Expand Down

0 comments on commit a2b71df

Please sign in to comment.