Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify function calling #176

Merged
merged 10 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .changes/generativeai/beef-collar-burn-aftermath.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Simplify function calling and remove provided function execution."]}
3 changes: 2 additions & 1 deletion generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ android {
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")

buildConfigField("String", "VERSION_NAME", "\"${project.version.toString()}\"")
buildConfigField("String", "VERSION_NAME", "\"${project.version}\"")
}

publishing {
Expand Down Expand Up @@ -85,6 +85,7 @@ dependencies {
implementation("com.google.guava:listenablefuture:1.0")
implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha03")
implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha03")
testImplementation("org.json:json:20210307") // Required for JSONObject to function in tests
testImplementation("junit:junit:4.13.2")
testImplementation("io.kotest:kotest-assertions-core:5.5.5")
testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,21 @@ import com.google.ai.client.generativeai.internal.util.toPublic
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.CountTokensResponse
import com.google.ai.client.generativeai.type.FinishReason
import com.google.ai.client.generativeai.type.FourParameterFunction
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerationConfig
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.NoParameterFunction
import com.google.ai.client.generativeai.type.OneParameterFunction
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.ThreeParameterFunction
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.ToolConfig
import com.google.ai.client.generativeai.type.TwoParameterFunction
import com.google.ai.client.generativeai.type.content
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.map
import kotlinx.serialization.ExperimentalSerializationApi
import org.json.JSONObject

/**
* A facilitator for a given multimodal model (eg; Gemini).
Expand Down Expand Up @@ -199,36 +191,6 @@ internal constructor(
return countTokens(content { image(prompt) })
}

/**
* Executes a function requested by the model.
*
* @param functionCallPart A [FunctionCallPart] from the model, containing a function call and
* parameters
* @return The output of the requested function call
*/
suspend fun executeFunction(functionCallPart: FunctionCallPart): JSONObject {
if (tools == null) {
throw InvalidStateException("No registered tools")
}
val callable =
tools.flatMap { it.functionDeclarations }.firstOrNull { it.name == functionCallPart.name }
?: throw InvalidStateException("No registered function named ${functionCallPart.name}")
return when (callable) {
is NoParameterFunction -> callable.execute()
is OneParameterFunction<*> ->
(callable as OneParameterFunction<Any?>).execute(functionCallPart)
is TwoParameterFunction<*, *> ->
(callable as TwoParameterFunction<Any?, Any?>).execute(functionCallPart)
is ThreeParameterFunction<*, *, *> ->
(callable as ThreeParameterFunction<Any?, Any?, Any?>).execute(functionCallPart)
is FourParameterFunction<*, *, *, *> ->
(callable as FourParameterFunction<Any?, Any?, Any?, Any?>).execute(functionCallPart)
else -> {
throw RuntimeException("UNREACHABLE")
}
}
}

private fun constructRequest(vararg prompt: Content) =
GenerateContentRequest(
modelName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part {
is com.google.ai.client.generativeai.type.BlobPart ->
BlobPart(Blob(mimeType, Base64.encodeToString(blob, BASE_64_FLAGS)))
is com.google.ai.client.generativeai.type.FunctionCallPart ->
FunctionCallPart(FunctionCall(name, args.orEmpty()))
FunctionCallPart(FunctionCall(name, args))
is com.google.ai.client.generativeai.type.FunctionResponsePart ->
FunctionResponsePart(FunctionResponse(name, response.toInternal()))
is com.google.ai.client.generativeai.type.FileDataPart ->
Expand Down Expand Up @@ -147,8 +147,8 @@ internal fun FunctionDeclaration.toInternal() =
name,
description,
Schema(
properties = getParameters().associate { it.name to it.toInternal() },
required = getParameters().map { it.name },
properties = parameters.associate { it.name to it.toInternal() },
required = requiredParameters,
type = "OBJECT",
nullable = false,
rlazo marked this conversation as resolved.
Show resolved Hide resolved
),
Expand Down Expand Up @@ -196,10 +196,7 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part {
}
}
is FunctionCallPart ->
com.google.ai.client.generativeai.type.FunctionCallPart(
functionCall.name,
functionCall.args.orEmpty(),
)
com.google.ai.client.generativeai.type.FunctionCallPart(functionCall.name, functionCall.args)
is FunctionResponsePart ->
com.google.ai.client.generativeai.type.FunctionResponsePart(
functionResponse.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,144 +19,24 @@ package com.google.ai.client.generativeai.type
import org.json.JSONObject

/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
*
* @property name The name of the function call, this should be clear and descriptive for the model
* @property description A description of what the function does and its output.
* @property function the function implementation
*/
class NoParameterFunction(
name: String,
description: String,
val function: suspend () -> JSONObject,
) : FunctionDeclaration(name, description) {
override fun getParameters() = listOf<Schema<Any>>()

suspend fun execute() = function()

override suspend fun execute(part: FunctionCallPart) = function()
}

/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
*
* @property name The name of the function call, this should be clear and descriptive for the model
* @property description A description of what the function does and its output.
* @property param A description of the first function parameter
* @property function the function implementation
*/
class OneParameterFunction<T>(
name: String,
description: String,
val param: Schema<T>,
val function: suspend (T) -> JSONObject,
) : FunctionDeclaration(name, description) {
override fun getParameters() = listOf(param)

override suspend fun execute(part: FunctionCallPart): JSONObject {
val arg1 = part.getArgOrThrow(param)
return function(arg1)
}
}

/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
*
* @property name The name of the function call, this should be clear and descriptive for the model
* @property description A description of what the function does and its output.
* @property param1 A description of the first function parameter
* @property param2 A description of the second function parameter
* @property function the function implementation
*/
class TwoParameterFunction<T, U>(
name: String,
description: String,
val param1: Schema<T>,
val param2: Schema<U>,
val function: suspend (T, U) -> JSONObject,
) : FunctionDeclaration(name, description) {
override fun getParameters() = listOf(param1, param2)

override suspend fun execute(part: FunctionCallPart): JSONObject {
val arg1 = part.getArgOrThrow(param1)
val arg2 = part.getArgOrThrow(param2)
return function(arg1, arg2)
}
}

/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
*
* @property name The name of the function call, this should be clear and descriptive for the model
* @property description A description of what the function does and its output.
* @property param1 A description of the first function parameter
* @property param2 A description of the second function parameter
* @property param3 A description of the third function parameter
* @property function the function implementation
*/
class ThreeParameterFunction<T, U, V>(
name: String,
description: String,
val param1: Schema<T>,
val param2: Schema<U>,
val param3: Schema<V>,
val function: suspend (T, U, V) -> JSONObject,
) : FunctionDeclaration(name, description) {
override fun getParameters() = listOf(param1, param2, param3)

override suspend fun execute(part: FunctionCallPart): JSONObject {
val arg1 = part.getArgOrThrow(param1)
val arg2 = part.getArgOrThrow(param2)
val arg3 = part.getArgOrThrow(param3)
return function(arg1, arg2, arg3)
}
}

/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
* Representation of a function that a model can invoke.
*
* @property name The name of the function call, this should be clear and descriptive for the model
* @property description A description of what the function does and its output.
* @property param1 A description of the first function parameter
* @property param2 A description of the second function parameter
* @property param3 A description of the third function parameter
* @property param4 A description of the fourth function parameter
* @property function the function implementation
* @see defineFunction
*/
class FourParameterFunction<T, U, V, W>(
name: String,
description: String,
val param1: Schema<T>,
val param2: Schema<U>,
val param3: Schema<V>,
val param4: Schema<W>,
val function: suspend (T, U, V, W) -> JSONObject,
) : FunctionDeclaration(name, description) {
override fun getParameters() = listOf(param1, param2, param3, param4)

override suspend fun execute(part: FunctionCallPart): JSONObject {
val arg1 = part.getArgOrThrow(param1)
val arg2 = part.getArgOrThrow(param2)
val arg3 = part.getArgOrThrow(param3)
val arg4 = part.getArgOrThrow(param4)
return function(arg1, arg2, arg3, arg4)
}
}

abstract class FunctionDeclaration(val name: String, val description: String) {
abstract fun getParameters(): List<Schema<out Any?>>

abstract suspend fun execute(part: FunctionCallPart): JSONObject
}
class FunctionDeclaration(
val name: String,
val description: String,
val parameters: List<Schema<*>>,
val requiredParameters: List<String>,
)

/**
* Represents a parameter for a declared function
*
* ```
* val currencyFrom = Schema.str("currencyFrom", "The currency to convert from.")
* ```
*
* @property name: The name of the parameter
* @property description: The description of what the parameter should contain or represent
* @property format: format information for the parameter, this can include bitlength in the case of
Expand All @@ -180,6 +60,21 @@ class Schema<T>(
val items: Schema<out Any>? = null,
val type: FunctionType<T>,
) {

/**
* Attempts to parse a string to the type [T] assigned to this schema.
*
* Will return null if the provided string is null. May also return null if the provided string is
* not a valid string of the expected type; but this should not be relied upon, as it may throw in
* certain scenarios (eg; the type is an object or array, and the string is not valid json).
*
* ```
* val currenciesSchema = Schema.arr("currencies", "The currencies available to use.")
* val currencies: List<String> = currenciesSchema.fromString("""
* ["USD", "EUR", "CAD", "GBP", "JPY"]
* """)
* ```
*/
fun fromString(value: String?) = type.parse(value)

companion object {
Expand Down Expand Up @@ -259,46 +154,31 @@ class Schema<T>(
}
}

fun defineFunction(name: String, description: String, function: suspend () -> JSONObject) =
NoParameterFunction(name, description, function)

fun <T> defineFunction(
name: String,
description: String,
arg1: Schema<T>,
function: suspend (T) -> JSONObject,
) = OneParameterFunction(name, description, arg1, function)

fun <T, U> defineFunction(
name: String,
description: String,
arg1: Schema<T>,
arg2: Schema<U>,
function: suspend (T, U) -> JSONObject,
) = TwoParameterFunction(name, description, arg1, arg2, function)

fun <T, U, W> defineFunction(
name: String,
description: String,
arg1: Schema<T>,
arg2: Schema<U>,
arg3: Schema<W>,
function: suspend (T, U, W) -> JSONObject,
) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function)

fun <T, U, W, Z> defineFunction(
/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
*
* ```
* val getExchangeRate = defineFunction(
* name = "getExchangeRate",
* description = "Get the exchange rate for currencies between countries.",
* parameters = listOf(
* Schema.str("currencyFrom", "The currency to convert from."),
* Schema.str("currencyTo", "The currency to convert to.")
* ),
* requiredParameters = listOf("currencyFrom", "currencyTo")
* )
* ```
*
* @param name The name of the function call, this should be clear and descriptive for the model.
* @param description A description of what the function does and its output.
* @param parameters A list of parameters that the function accepts.
* @param requiredParameters A list of parameters that the function requires to run.
* @see Schema
*/
fun defineFunction(
name: String,
description: String,
arg1: Schema<T>,
arg2: Schema<U>,
arg3: Schema<W>,
arg4: Schema<Z>,
function: suspend (T, U, W, Z) -> JSONObject,
) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function)

private fun <T> FunctionCallPart.getArgOrThrow(param: Schema<T>): T {
return param.fromString(args[param.name])
?: throw RuntimeException(
"Missing argument for parameter \"${param.name}\" for function \"$name\""
)
}
parameters: List<Schema<*>> = emptyList(),
requiredParameters: List<String> = emptyList(),
) = FunctionDeclaration(name, description, parameters, requiredParameters)
Loading
Loading