Skip to content

Commit

Permalink
Alternative implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
rlazo committed Feb 10, 2024
1 parent 1136610 commit b9ce8c8
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 310 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ import com.google.ai.client.generativeai.type.CountTokensResponse
import com.google.ai.client.generativeai.type.FinishReason
import com.google.ai.client.generativeai.type.FourParameterFunction
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.FunctionParameter
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerationConfig
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
import com.google.ai.client.generativeai.type.NoParameterFunction
import com.google.ai.client.generativeai.type.OneParameterFunction
import com.google.ai.client.generativeai.type.ParameterDeclaration
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SafetySetting
Expand Down Expand Up @@ -182,71 +182,34 @@ internal constructor(
/**
* Executes a function request by the model.
*
* @param call A [FunctionCallPart] from the model, containing a function call and parameters
* @param functionCallPart A [FunctionCallPart] from the model, containing a function call and
* parameters
* @return The output of the requested function call
*/
@BetaGenAiAPI
suspend fun executeFunction(call: FunctionCallPart): String {
fun executeFunction(functionCallPart: FunctionCallPart): String {
if (tools == null) {
throw RuntimeException("No registered tools")
}
val tool = tools.first { it.functionDeclarations.any { it.name == call.name } }
val declaration =
tool.functionDeclarations.firstOrNull() { it.name == call.name }
?: throw RuntimeException("No registered function named ${call.name}")
return when (declaration) {
is NoParameterFunction -> {
declaration.function.invoke()
}
is OneParameterFunction<*> -> {
declaration
.let { declaration as OneParameterFunction<Any?> }
.let {
val param1 = getParamOrThrow(it.param, call)
it.function.invoke(param1)
}
}
is TwoParameterFunction<*, *> -> {
declaration
.let { declaration as TwoParameterFunction<Any?, Any?> }
.let {
val param1 = getParamOrThrow(it.param1, call)
val param2 = getParamOrThrow(it.param2, call)
it.function.invoke(param1, param2)
}
}
is ThreeParameterFunction<*, *, *> -> {
declaration
.let { declaration as ThreeParameterFunction<Any?, Any?, Any?> }
.let {
val param1 = getParamOrThrow(it.param1, call)
val param2 = getParamOrThrow(it.param2, call)
val param3 = getParamOrThrow(it.param3, call)
it.function.invoke(param1, param2, param3)
}
}
is FourParameterFunction<*, *, *, *> -> {
declaration
.let { declaration as FourParameterFunction<Any?, Any?, Any?, Any?> }
.let {
val param1 = getParamOrThrow(it.param1, call)
val param2 = getParamOrThrow(it.param2, call)
val param3 = getParamOrThrow(it.param3, call)
val param4 = getParamOrThrow(it.param4, call)
it.function.invoke(param1, param2, param3, param4)
}
}
val tool = tools.first { it.functionDeclarations.any { it.name == functionCallPart.name } }
val callable =
tool.functionDeclarations.firstOrNull() { it.name == functionCallPart.name }
?: throw RuntimeException("No registered function named ${functionCallPart.name}")
return when (callable) {
is NoParameterFunction -> callable()
is OneParameterFunction<*> -> (callable as OneParameterFunction<Any?>)(functionCallPart)
is TwoParameterFunction<*, *> ->
(callable as TwoParameterFunction<Any?, Any?>)(functionCallPart)
is ThreeParameterFunction<*, *, *> ->
(callable as ThreeParameterFunction<Any?, Any?, Any?>)(functionCallPart)
is FourParameterFunction<*, *, *, *> ->
(callable as FourParameterFunction<Any?, Any?, Any?, Any?>)(functionCallPart)
else -> {
throw RuntimeException("UNREACHABLE")
}
}
}

private fun <T> getParamOrThrow(param: FunctionParameter<T>, part: FunctionCallPart): T {
return param.type.parse.invoke(part.args[param.name])
?: throw RuntimeException("Missing parameter named ${param.name} for function ${part.name}")
}

private fun constructRequest(vararg prompt: Content) =
GenerateContentRequest(
modelName,
Expand Down
Loading

0 comments on commit b9ce8c8

Please sign in to comment.