Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Actual function calling in java sample
Browse files Browse the repository at this point in the history
rlazo committed Jul 13, 2024
1 parent b1855ed commit 9e673c0
Showing 2 changed files with 90 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ import org.json.JSONObject
* @param functionDeclarations The set of functions that this tool allows the model access to
* @param codeExecution This is a flag value to enable Code Execution. Use [CODE_EXECUTION].
*/
class Tool(
class Tool @JvmOverloads constructor(
val functionDeclarations: List<FunctionDeclaration>? = null,
val codeExecution: JSONObject? = null,
) {
Original file line number Diff line number Diff line change
@@ -14,24 +14,41 @@

package com.google.ai.client.generative.samples.java;

// Set up your API Key
// ====================
//
// To use the Gemini API, you'll need an API key. To learn more, see
// the "Set up your API Key section" in the [Gemini API
// quickstart](https://ai.google.dev/gemini-api/docs/quickstart?lang=android#set-up-api-key).

import static com.google.ai.client.generativeai.type.FunctionDeclarationsKt.defineFunction;

import com.google.ai.client.generativeai.GenerativeModel;
import com.google.ai.client.generativeai.java.ChatFutures;
import com.google.ai.client.generativeai.java.GenerativeModelFutures;
import com.google.ai.client.generativeai.type.Content;
import com.google.ai.client.generativeai.type.FunctionCallPart;
import com.google.ai.client.generativeai.type.FunctionDeclaration;
import com.google.ai.client.generativeai.type.FunctionResponsePart;
import com.google.ai.client.generativeai.type.GenerateContentResponse;
import com.google.ai.client.generativeai.type.RequestOptions;
import com.google.ai.client.generativeai.type.Schema;
import com.google.ai.client.generativeai.type.Tool;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import java.util.Arrays;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.json.JSONException;
import org.json.JSONObject;

// Set up your API Key
// ====================
//
// To use the Gemini API, you'll need an API key. To learn more, see
// the "Set up your API Key section" in the [Gemini API
// quickstart](https://ai.google.dev/gemini-api/docs/quickstart?lang=android#set-up-api-key).

class FunctionCalling {

double multiply(double a, double b) {
return a * b;
}

void functionCalling() {
// [START function_calling]
FunctionDeclaration multiplyDefinition =
@@ -43,27 +60,83 @@ void functionCalling() {
Schema.numDouble("b", "Second parameter")),
/* required */ Arrays.asList("a", "b"));

Tool tool = new Tool(Arrays.asList(multiplyDefinition), null);

// Specify a Gemini model appropriate for your use case
GenerativeModel gm =
new GenerativeModel(
/* modelName */ "gemini-1.5-flash",
// Access your API key as a Build Configuration variable (see "Set up your API key"
// above)
/* apiKey */ BuildConfig.apiKey);
/* apiKey */ BuildConfig.apiKey,
/* generationConfig (optional) */ null,
/* safetySettings (optional) */ null,
/* requestOptions (optional) */ new RequestOptions(),
/* functionDeclarations (optional) */ Arrays.asList(tool));
GenerativeModelFutures model = GenerativeModelFutures.from(gm);

// (optional) Create previous chat history for context
// Create prompt
Content.Builder userContentBuilder = new Content.Builder();
userContentBuilder.setRole("user");
userContentBuilder.addText("Hello, I have 2 dogs in my house.");
Content userContent = userContentBuilder.build();
userContentBuilder.addText(
"I have 57 cats, each owns 44 mittens, how many mittens is that in total?");
Content userMessage = userContentBuilder.build();

// For illustrative purposes only. You should use an executor that fits your needs.
Executor executor = Executors.newSingleThreadExecutor();

// Initialize the chat
ChatFutures chat = model.startChat();

// Send the message
ListenableFuture<GenerateContentResponse> response = chat.sendMessage(userMessage);

Futures.addCallback(
response,
new FutureCallback<GenerateContentResponse>() {
@Override
public void onSuccess(GenerateContentResponse result) {
if (!result.getFunctionCalls().isEmpty()) {
handleFunctionCall(result);
}
if (!result.getText().isEmpty()) {
System.out.println(result.getText());
}
}

@Override
public void onFailure(Throwable t) {
t.printStackTrace();
}

private void handleFunctionCall(GenerateContentResponse result) {
FunctionCallPart multiplyFunctionCallPart =
result.getFunctionCalls().stream()
.filter(fun -> fun.getName().equals("multiply"))
.findFirst()
.get();
double a = Double.parseDouble(multiplyFunctionCallPart.getArgs().get("a"));
double b = Double.parseDouble(multiplyFunctionCallPart.getArgs().get("b"));

try {
// `multiply(a, b)` is a regular java function defined in another class
FunctionResponsePart functionResponsePart =
new FunctionResponsePart(
"multiply", new JSONObject().put("result", multiply(a, b)));

Content.Builder modelContentBuilder = new Content.Builder();
modelContentBuilder.setRole("model");
modelContentBuilder.addText("Great to meet you. What would you like to know?");
Content modelContent = userContentBuilder.build();
// Create prompt
Content.Builder functionCallResponse = new Content.Builder();
userContentBuilder.setRole("user");
userContentBuilder.addPart(functionResponsePart);
Content userMessage = userContentBuilder.build();

// List<Content> history = Arrays.asList(userContent, modelContent);
chat.sendMessage(userMessage);
} catch (JSONException e) {
throw new RuntimeException(e);
}
}
},
executor);

// [END function_calling]
}

0 comments on commit 9e673c0

Please sign in to comment.