diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 72d5e6c3..ea0cefc2 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -28,12 +28,12 @@ import com.google.ai.client.generativeai.type.CountTokensResponse import com.google.ai.client.generativeai.type.FinishReason import com.google.ai.client.generativeai.type.FourParameterFunction import com.google.ai.client.generativeai.type.FunctionCallPart -import com.google.ai.client.generativeai.type.FunctionParameter import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.GenerationConfig import com.google.ai.client.generativeai.type.GoogleGenerativeAIException import com.google.ai.client.generativeai.type.NoParameterFunction import com.google.ai.client.generativeai.type.OneParameterFunction +import com.google.ai.client.generativeai.type.ParameterDeclaration import com.google.ai.client.generativeai.type.PromptBlockedException import com.google.ai.client.generativeai.type.ResponseStoppedException import com.google.ai.client.generativeai.type.SafetySetting @@ -182,71 +182,34 @@ internal constructor( /** * Executes a function request by the model. * - * @param call A [FunctionCallPart] from the model, containing a function call and parameters + * @param functionCallPart A [FunctionCallPart] from the model, containing a function call and + * parameters * @return The output of the requested function call */ @BetaGenAiAPI - suspend fun executeFunction(call: FunctionCallPart): String { + fun executeFunction(functionCallPart: FunctionCallPart): String { if (tools == null) { throw RuntimeException("No registered tools") } - val tool = tools.first { it.functionDeclarations.any { it.name == call.name } } - val declaration = - tool.functionDeclarations.firstOrNull() { it.name == call.name } - ?: throw RuntimeException("No registered function named ${call.name}") - return when (declaration) { - is NoParameterFunction -> { - declaration.function.invoke() - } - is OneParameterFunction<*> -> { - declaration - .let { declaration as OneParameterFunction } - .let { - val param1 = getParamOrThrow(it.param, call) - it.function.invoke(param1) - } - } - is TwoParameterFunction<*, *> -> { - declaration - .let { declaration as TwoParameterFunction } - .let { - val param1 = getParamOrThrow(it.param1, call) - val param2 = getParamOrThrow(it.param2, call) - it.function.invoke(param1, param2) - } - } - is ThreeParameterFunction<*, *, *> -> { - declaration - .let { declaration as ThreeParameterFunction } - .let { - val param1 = getParamOrThrow(it.param1, call) - val param2 = getParamOrThrow(it.param2, call) - val param3 = getParamOrThrow(it.param3, call) - it.function.invoke(param1, param2, param3) - } - } - is FourParameterFunction<*, *, *, *> -> { - declaration - .let { declaration as FourParameterFunction } - .let { - val param1 = getParamOrThrow(it.param1, call) - val param2 = getParamOrThrow(it.param2, call) - val param3 = getParamOrThrow(it.param3, call) - val param4 = getParamOrThrow(it.param4, call) - it.function.invoke(param1, param2, param3, param4) - } - } + val tool = tools.first { it.functionDeclarations.any { it.name == functionCallPart.name } } + val callable = + tool.functionDeclarations.firstOrNull() { it.name == functionCallPart.name } + ?: throw RuntimeException("No registered function named ${functionCallPart.name}") + return when (callable) { + is NoParameterFunction -> callable() + is OneParameterFunction<*> -> (callable as OneParameterFunction)(functionCallPart) + is TwoParameterFunction<*, *> -> + (callable as TwoParameterFunction)(functionCallPart) + is ThreeParameterFunction<*, *, *> -> + (callable as ThreeParameterFunction)(functionCallPart) + is FourParameterFunction<*, *, *, *> -> + (callable as FourParameterFunction)(functionCallPart) else -> { throw RuntimeException("UNREACHABLE") } } } - private fun getParamOrThrow(param: FunctionParameter, part: FunctionCallPart): T { - return param.type.parse.invoke(part.args[param.name]) - ?: throw RuntimeException("Missing parameter named ${param.name} for function ${part.name}") - } - private fun constructRequest(vararg prompt: Content) = GenerateContentRequest( modelName, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index fefe11f8..ada82d3f 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -28,9 +28,13 @@ package com.google.ai.client.generativeai.type class NoParameterFunction( name: String, description: String, - val function: suspend () -> String, + val function: () -> String, ) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf>() + override fun getParameters() = listOf>() + + operator fun invoke() = function() + + operator fun invoke(part: FunctionCallPart) = invoke() } /** @@ -46,10 +50,15 @@ class NoParameterFunction( class OneParameterFunction( name: String, description: String, - val param: FunctionParameter, - val function: suspend (T) -> String, + val param: ParameterDeclaration, + val function: (T) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param) + + operator fun invoke(part: FunctionCallPart): String { + val arg1 = part.getArgOrThrow(param) + return function(arg1) + } } /** @@ -66,11 +75,17 @@ class OneParameterFunction( class TwoParameterFunction( name: String, description: String, - val param1: FunctionParameter, - val param2: FunctionParameter, - val function: suspend (T, U) -> String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val function: (T, U) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2) + + operator fun invoke(part: FunctionCallPart): String { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + return function(arg1, arg2) + } } /** @@ -88,12 +103,19 @@ class TwoParameterFunction( class ThreeParameterFunction( name: String, description: String, - val param1: FunctionParameter, - val param2: FunctionParameter, - val param3: FunctionParameter, - val function: suspend (T, U, V) -> String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val param3: ParameterDeclaration, + val function: (T, U, V) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3) + + operator fun invoke(part: FunctionCallPart): String { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + return function(arg1, arg2, arg3) + } } /** @@ -112,13 +134,21 @@ class ThreeParameterFunction( class FourParameterFunction( name: String, description: String, - val param1: FunctionParameter, - val param2: FunctionParameter, - val param3: FunctionParameter, - val param4: FunctionParameter, - val function: suspend (T, U, V, W) -> String, + val param1: ParameterDeclaration, + val param2: ParameterDeclaration, + val param3: ParameterDeclaration, + val param4: ParameterDeclaration, + val function: (T, U, V, W) -> String, ) : FunctionDeclaration(name, description) { override fun getParameters() = listOf(param1, param2, param3, param4) + + operator fun invoke(part: FunctionCallPart): String { + val arg1 = part.getArgOrThrow(param1) + val arg2 = part.getArgOrThrow(param2) + val arg3 = part.getArgOrThrow(param3) + val arg4 = part.getArgOrThrow(param4) + return function(arg1, arg2, arg3, arg4) + } } @BetaGenAiAPI @@ -126,259 +156,73 @@ abstract class FunctionDeclaration( val name: String, val description: String, ) { - abstract fun getParameters(): List> + abstract fun getParameters(): List> } -/** - * A builder to help build [FunctionDeclaration] objects - * - * @property name The name of the function call, this should be clear and descriptive for the model - * @property description A description of what the function does and its output. - */ -@BetaGenAiAPI -class FunctionBuilder(private val name: String, private val description: String) { - - fun build(function: suspend () -> String): FunctionDeclaration { - return NoParameterFunction(name, description, function) - } - - fun param(param: FunctionParameter): OneFunctionBuilder { - return OneFunctionBuilder(name, description, param) - } - - fun param( - paramName: String, - paramDescription: String, - type: FunctionType - ): OneFunctionBuilder { - return OneFunctionBuilder( - name, - description, - FunctionParameter(paramName, paramDescription, type) - ) - } - - fun stringParam(paramName: String, paramDescription: String): OneFunctionBuilder { - return OneFunctionBuilder( - name, - description, - FunctionParameter(paramName, paramDescription, FunctionType.STRING) - ) - } - - fun intParam(paramName: String, paramDescription: String): OneFunctionBuilder { - return OneFunctionBuilder( - name, - description, - FunctionParameter(paramName, paramDescription, FunctionType.INT) - ) - } - - fun boolParam(paramName: String, paramDescription: String): OneFunctionBuilder { - return OneFunctionBuilder( - name, - description, - FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) - ) - } -} - -@BetaGenAiAPI -class OneFunctionBuilder( - private val name: String, - private val description: String, - private val param1: FunctionParameter +class ParameterDeclaration( + val name: String, + val description: String, + private val type: FunctionType ) { - fun build(function: suspend (T) -> String): FunctionDeclaration { - return OneParameterFunction(name, description, param1, function) - } + fun fromString(value: String?) = type.parse(value) - fun param(param: FunctionParameter): TwoFunctionBuilder { - return TwoFunctionBuilder(name, description, param1, param) - } + companion object { + fun int(name: String, description: String) = + ParameterDeclaration(name, description, FunctionType.INT) - fun param( - paramName: String, - paramDescription: String, - type: FunctionType - ): TwoFunctionBuilder { - return TwoFunctionBuilder( - name, - description, - param1, - FunctionParameter(paramName, paramDescription, type) - ) - } + fun string(name: String, description: String) = + ParameterDeclaration(name, description, FunctionType.STRING) - fun stringParam(paramName: String, paramDescription: String): TwoFunctionBuilder { - return TwoFunctionBuilder( - name, - description, - param1, - FunctionParameter(paramName, paramDescription, FunctionType.STRING) - ) - } - - fun intParam(paramName: String, paramDescription: String): TwoFunctionBuilder { - return TwoFunctionBuilder( - name, - description, - param1, - FunctionParameter(paramName, paramDescription, FunctionType.INT) - ) - } - - fun boolParam(paramName: String, paramDescription: String): TwoFunctionBuilder { - return TwoFunctionBuilder( - name, - description, - param1, - FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) - ) + fun boolean(name: String, description: String) = + ParameterDeclaration(name, description, FunctionType.BOOLEAN) } } @BetaGenAiAPI -class TwoFunctionBuilder( - private val name: String, - private val description: String, - private val param1: FunctionParameter, - private val param2: FunctionParameter, -) { - fun build(function: suspend (T, U) -> String): FunctionDeclaration { - return TwoParameterFunction(name, description, param1, param2, function) - } - - fun param(param: FunctionParameter): ThreeFunctionBuilder { - return ThreeFunctionBuilder(name, description, param1, param2, param) - } - - fun param( - paramName: String, - paramDescription: String, - type: FunctionType - ): ThreeFunctionBuilder { - return ThreeFunctionBuilder( - name, - description, - param1, - param2, - FunctionParameter(paramName, paramDescription, type) - ) - } - - fun stringParam(paramName: String, paramDescription: String): ThreeFunctionBuilder { - return ThreeFunctionBuilder( - name, - description, - param1, - param2, - FunctionParameter(paramName, paramDescription, FunctionType.STRING) - ) - } - - fun intParam(paramName: String, paramDescription: String): ThreeFunctionBuilder { - return ThreeFunctionBuilder( - name, - description, - param1, - param2, - FunctionParameter(paramName, paramDescription, FunctionType.INT) - ) - } - - fun boolParam(paramName: String, paramDescription: String): ThreeFunctionBuilder { - return ThreeFunctionBuilder( - name, - description, - param1, - param2, - FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) - ) - } -} +fun defineFunction(name: String, description: String, function: () -> String) = + NoParameterFunction(name, description, function) @BetaGenAiAPI -class ThreeFunctionBuilder( - private val name: String, - private val description: String, - private val param1: FunctionParameter, - private val param2: FunctionParameter, - private val param3: FunctionParameter, -) { - fun build(function: suspend (T, U, V) -> String): FunctionDeclaration { - return ThreeParameterFunction(name, description, param1, param2, param3, function) - } - - fun param(param: FunctionParameter): FourFunctionBuilder { - return FourFunctionBuilder(name, description, param1, param2, param3, param) - } - - fun param( - paramName: String, - paramDescription: String, - type: FunctionType - ): FourFunctionBuilder { - return FourFunctionBuilder( - name, - description, - param1, - param2, - param3, - FunctionParameter(paramName, paramDescription, type) - ) - } - - fun stringParam( - paramName: String, - paramDescription: String - ): FourFunctionBuilder { - return FourFunctionBuilder( - name, - description, - param1, - param2, - param3, - FunctionParameter(paramName, paramDescription, FunctionType.STRING) - ) - } +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + function: (T) -> String +) = OneParameterFunction(name, description, arg1, function) - fun intParam(paramName: String, paramDescription: String): FourFunctionBuilder { - return FourFunctionBuilder( - name, - description, - param1, - param2, - param3, - FunctionParameter(paramName, paramDescription, FunctionType.INT) - ) - } +@BetaGenAiAPI +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + function: (T, U) -> String +) = TwoParameterFunction(name, description, arg1, arg2, function) - fun boolParam( - paramName: String, - paramDescription: String - ): FourFunctionBuilder { - return FourFunctionBuilder( - name, - description, - param1, - param2, - param3, - FunctionParameter(paramName, paramDescription, FunctionType.BOOLEAN) - ) - } -} +@BetaGenAiAPI +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + arg3: ParameterDeclaration, + function: (T, U, W) -> String +) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) @BetaGenAiAPI -class FourFunctionBuilder( - private val name: String, - private val description: String, - private val param1: FunctionParameter, - private val param2: FunctionParameter, - private val param3: FunctionParameter, - private val param4: FunctionParameter, -) { - fun build(function: suspend (T, U, V, W) -> String): FunctionDeclaration { - return FourParameterFunction(name, description, param1, param2, param3, param4, function) - } +fun defineFunction( + name: String, + description: String, + arg1: ParameterDeclaration, + arg2: ParameterDeclaration, + arg3: ParameterDeclaration, + arg4: ParameterDeclaration, + function: (T, U, W, Z) -> String +) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) + +private fun FunctionCallPart.getArgOrThrow(param: ParameterDeclaration): T { + return param.fromString(args[param.name]) + ?: throw RuntimeException( + "Missing argument for parameter \"${param.name}\" for function \"$name\"" + ) } diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index da0e90b6..c3920e9f 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -16,6 +16,11 @@ package com.google.ai.client.generativeai +import com.google.ai.client.generativeai.type.BetaGenAiAPI +import com.google.ai.client.generativeai.type.FunctionCallPart +import com.google.ai.client.generativeai.type.ParameterDeclaration +import com.google.ai.client.generativeai.type.TwoParameterFunction +import com.google.ai.client.generativeai.type.defineFunction import com.google.ai.client.generativeai.util.commonTest import com.google.ai.client.generativeai.util.createResponses import com.google.ai.client.generativeai.util.prepareStreamingResponse @@ -23,7 +28,7 @@ import io.kotest.matchers.shouldBe import io.ktor.utils.io.close import io.ktor.utils.io.writeFully import kotlin.time.Duration.Companion.seconds -import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout import org.junit.Test @@ -45,4 +50,40 @@ internal class GenerativeModelTests { } } } + + // + // + // FOR DEV PURPOSES ONLY + // + // + + fun myfun(a: Int, b: Int): String { + return (a + b).toString() + } + + @OptIn(BetaGenAiAPI::class) + @Test + fun `calling test`(): Unit = runBlocking { + // val f = + // FunctionBuilder("sum", "add two numbers together") + // .intParam("a", "First number to add together") + // .intParam("b", "Second number to add together") + // .build(::myfun) + + val f = + defineFunction( + "sum", + "add two numbers together", + ParameterDeclaration.int("a", "First number to add together"), + ParameterDeclaration.int("b", "Second number to add together") + ) { a, b -> + (a + b).toString() + } + + val x = f as TwoParameterFunction + val p = FunctionCallPart("sum", mapOf("a" to "2", "b" to "3")) + val q = x(p) + + q shouldBe "5" + } }