diff --git a/.changes/generativeai/bead-bead-basketball-boat.json b/.changes/generativeai/bead-bead-basketball-boat.json new file mode 100644 index 00000000..0a435a03 --- /dev/null +++ b/.changes/generativeai/bead-bead-basketball-boat.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["Improve java support for function calling"]} diff --git a/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/Tool.kt b/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/Tool.kt index dd6d3c13..254f9c20 100644 --- a/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/Tool.kt +++ b/generativeai/src/commonMain/kotlin/dev/shreyaspatil/ai/client/generativeai/type/Tool.kt @@ -22,7 +22,9 @@ package dev.shreyaspatil.ai.client.generativeai.type * @param functionDeclarations The set of functions that this tool allows the model access to * @param codeExecution This is a flag value to enable Code Execution. Use [CODE_EXECUTION]. */ -class Tool( +class Tool +@JvmOverloads +constructor( val functionDeclarations: List? = null, val codeExecution: JSONObject? = null, ) { diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 88247f3d..347b0d20 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -79,6 +79,8 @@ class Schema( companion object { /** Registers a schema for a 32 bit integer number */ + @JvmStatic + @JvmName("numInt") fun int(name: String, description: String) = Schema( name = name, @@ -89,6 +91,8 @@ class Schema( ) /** Registers a schema for a 64 bit integer number */ + @JvmStatic + @JvmName("numLong") fun long(name: String, description: String) = Schema( name = name, @@ -98,6 +102,7 @@ class Schema( ) /** Registers a schema for a string */ + @JvmStatic fun str(name: String, description: String) = Schema( name = name, @@ -107,6 +112,7 @@ class Schema( ) /** Registers a schema for a boolean */ + @JvmStatic fun bool(name: String, description: String) = Schema( name = name, @@ -116,6 +122,8 @@ class Schema( ) /** Registers a schema for a floating point number */ + @JvmStatic + @JvmName("numDouble") fun double(name: String, description: String) = Schema( name = name, @@ -127,6 +135,7 @@ class Schema( /** * Registers a schema for a complex object. In a function it will be returned as a [JSONObject] */ + @JvmStatic fun obj(name: String, description: String, vararg contents: Schema) = Schema( name = name, @@ -142,6 +151,7 @@ class Schema( * * @param items can be used to specify the type of the array */ + @JvmStatic fun arr(name: String, description: String, items: Schema? = null) = Schema>( name = name, @@ -152,6 +162,8 @@ class Schema( ) /** Registers a schema for an enum */ + @JvmStatic + @JvmName("enumeration") fun enum(name: String, description: String, values: List) = Schema( name = name, diff --git a/samples/src/main/java/com/google/ai/client/generative/samples/function_calling.kt b/samples/src/main/java/com/google/ai/client/generative/samples/function_calling.kt index 894336f9..246263a3 100644 --- a/samples/src/main/java/com/google/ai/client/generative/samples/function_calling.kt +++ b/samples/src/main/java/com/google/ai/client/generative/samples/function_calling.kt @@ -16,6 +16,15 @@ package com.google.ai.client.generative.samples +import com.google.ai.client.generativeai.GenerativeModel +import com.google.ai.client.generativeai.type.FunctionResponsePart +import com.google.ai.client.generativeai.type.InvalidStateException +import com.google.ai.client.generativeai.type.Schema +import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.content +import com.google.ai.client.generativeai.type.defineFunction +import org.json.JSONObject + // Set up your API Key // ==================== @@ -24,4 +33,53 @@ package com.google.ai.client.generative.samples // the "Set up your API Key section" in the [Gemini API // quickstart](https://ai.google.dev/gemini-api/docs/quickstart?lang=android#set-up-api-key). -// TODO +suspend fun functionCalling() { + // [START function_calling] + fun multiply(a: Double, b: Double) = a * b + + val multiplyDefinition = defineFunction( + name = "multiply", + description = "returns the product of the provided numbers.", + parameters = listOf( + Schema.double("a", "First number"), + Schema.double("b", "Second number") + ) + ) + + val usableFunctions = listOf(multiplyDefinition) + + val generativeModel = + GenerativeModel( + // Specify a Gemini model appropriate for your use case + modelName = "gemini-1.5-flash", + // Access your API key as a Build Configuration variable (see "Set up your API key" above) + apiKey = BuildConfig.apiKey, + // List the functions definitions you want to make available to the model + tools = listOf(Tool(usableFunctions)) + ) + + val chat = generativeModel.startChat() + val prompt = "I have 57 cats, each owns 44 mittens, how many mittens is that in total?" + + // Send the message to the generative model + var response = chat.sendMessage(prompt) + + // Check if the model responded with a function call + response.functionCalls.first { it.name == "multiply" }.apply { + val a: String by args + val b: String by args + + val result = JSONObject(mapOf("result" to multiply(a.toDouble(), b.toDouble()))) + response = chat.sendMessage( + content(role = "function") { + part(FunctionResponsePart("multiply", result)) + } + ) + } + + // Whenever the model responds with text, show it in the UI + response.text?.let { modelResponse -> + println(modelResponse) + } + // [END function_calling] +} diff --git a/samples/src/main/java/com/google/ai/client/generative/samples/java/function_calling.java b/samples/src/main/java/com/google/ai/client/generative/samples/java/function_calling.java index e233e385..9e31e0d8 100644 --- a/samples/src/main/java/com/google/ai/client/generative/samples/java/function_calling.java +++ b/samples/src/main/java/com/google/ai/client/generative/samples/java/function_calling.java @@ -14,6 +14,28 @@ package com.google.ai.client.generative.samples.java; +import static com.google.ai.client.generativeai.type.FunctionDeclarationsKt.defineFunction; + +import com.google.ai.client.generativeai.GenerativeModel; +import com.google.ai.client.generativeai.java.ChatFutures; +import com.google.ai.client.generativeai.java.GenerativeModelFutures; +import com.google.ai.client.generativeai.type.Content; +import com.google.ai.client.generativeai.type.FunctionCallPart; +import com.google.ai.client.generativeai.type.FunctionDeclaration; +import com.google.ai.client.generativeai.type.FunctionResponsePart; +import com.google.ai.client.generativeai.type.GenerateContentResponse; +import com.google.ai.client.generativeai.type.RequestOptions; +import com.google.ai.client.generativeai.type.Schema; +import com.google.ai.client.generativeai.type.Tool; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import java.util.Arrays; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import org.json.JSONException; +import org.json.JSONObject; + // Set up your API Key // ==================== // @@ -22,5 +44,100 @@ // quickstart](https://ai.google.dev/gemini-api/docs/quickstart?lang=android#set-up-api-key). class FunctionCalling { - // TODO + + double multiply(double a, double b) { + return a * b; + } + + void functionCalling() { + // [START function_calling] + FunctionDeclaration multiplyDefinition = + defineFunction( + /* name */ "multiply", + /* description */ "returns a * b.", + /* parameters */ Arrays.asList( + Schema.numDouble("a", "First parameter"), + Schema.numDouble("b", "Second parameter")), + /* required */ Arrays.asList("a", "b")); + + Tool tool = new Tool(Arrays.asList(multiplyDefinition), null); + + // Specify a Gemini model appropriate for your use case + GenerativeModel gm = + new GenerativeModel( + /* modelName */ "gemini-1.5-flash", + // Access your API key as a Build Configuration variable (see "Set up your API key" + // above) + /* apiKey */ BuildConfig.apiKey, + /* generationConfig (optional) */ null, + /* safetySettings (optional) */ null, + /* requestOptions (optional) */ new RequestOptions(), + /* functionDeclarations (optional) */ Arrays.asList(tool)); + GenerativeModelFutures model = GenerativeModelFutures.from(gm); + + // Create prompt + Content.Builder userContentBuilder = new Content.Builder(); + userContentBuilder.setRole("user"); + userContentBuilder.addText( + "I have 57 cats, each owns 44 mittens, how many mittens is that in total?"); + Content userMessage = userContentBuilder.build(); + + // For illustrative purposes only. You should use an executor that fits your needs. + Executor executor = Executors.newSingleThreadExecutor(); + + // Initialize the chat + ChatFutures chat = model.startChat(); + + // Send the message + ListenableFuture response = chat.sendMessage(userMessage); + + Futures.addCallback( + response, + new FutureCallback() { + @Override + public void onSuccess(GenerateContentResponse result) { + if (!result.getFunctionCalls().isEmpty()) { + handleFunctionCall(result); + } + if (!result.getText().isEmpty()) { + System.out.println(result.getText()); + } + } + + @Override + public void onFailure(Throwable t) { + t.printStackTrace(); + } + + private void handleFunctionCall(GenerateContentResponse result) { + FunctionCallPart multiplyFunctionCallPart = + result.getFunctionCalls().stream() + .filter(fun -> fun.getName().equals("multiply")) + .findFirst() + .get(); + double a = Double.parseDouble(multiplyFunctionCallPart.getArgs().get("a")); + double b = Double.parseDouble(multiplyFunctionCallPart.getArgs().get("b")); + + try { + // `multiply(a, b)` is a regular java function defined in another class + FunctionResponsePart functionResponsePart = + new FunctionResponsePart( + "multiply", new JSONObject().put("result", multiply(a, b))); + + // Create prompt + Content.Builder functionCallResponse = new Content.Builder(); + userContentBuilder.setRole("user"); + userContentBuilder.addPart(functionResponsePart); + Content userMessage = userContentBuilder.build(); + + chat.sendMessage(userMessage); + } catch (JSONException e) { + throw new RuntimeException(e); + } + } + }, + executor); + + // [END function_calling] + } }