Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function calling snippets #200

Merged
merged 9 commits into from
Jul 15, 2024
Merged
1 change: 1 addition & 0 deletions .changes/generativeai/bead-bead-basketball-boat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Improve java support for function calling"]}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class Schema<T>(

companion object {
/** Registers a schema for a 32 bit integer number */
@JvmStatic
@JvmName("numInt")
daymxn marked this conversation as resolved.
Show resolved Hide resolved
fun int(name: String, description: String) =
Schema<Int>(
name = name,
Expand All @@ -89,6 +91,8 @@ class Schema<T>(
)

/** Registers a schema for a 64 bit integer number */
@JvmStatic
@JvmName("numLong")
fun long(name: String, description: String) =
Schema<Long>(
name = name,
Expand All @@ -98,6 +102,7 @@ class Schema<T>(
)

/** Registers a schema for a string */
@JvmStatic
fun str(name: String, description: String) =
Schema<String>(
name = name,
Expand All @@ -107,6 +112,7 @@ class Schema<T>(
)

/** Registers a schema for a boolean */
@JvmStatic
fun bool(name: String, description: String) =
Schema<Boolean>(
name = name,
Expand All @@ -116,6 +122,8 @@ class Schema<T>(
)

/** Registers a schema for a floating point number */
@JvmStatic
@JvmName("numDouble")
fun double(name: String, description: String) =
Schema<Double>(
name = name,
Expand All @@ -127,6 +135,7 @@ class Schema<T>(
/**
* Registers a schema for a complex object. In a function it will be returned as a [JSONObject]
*/
@JvmStatic
fun obj(name: String, description: String, vararg contents: Schema<out Any>) =
Schema<JSONObject>(
name = name,
Expand All @@ -142,6 +151,7 @@ class Schema<T>(
*
* @param items can be used to specify the type of the array
*/
@JvmStatic
fun arr(name: String, description: String, items: Schema<out Any>? = null) =
Schema<List<String>>(
name = name,
Expand All @@ -152,6 +162,8 @@ class Schema<T>(
)

/** Registers a schema for an enum */
@JvmStatic
@JvmName("enumeration")
fun enum(name: String, description: String, values: List<String>) =
Schema<String>(
name = name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ import org.json.JSONObject
* @param functionDeclarations The set of functions that this tool allows the model access to
* @param codeExecution This is a flag value to enable Code Execution. Use [CODE_EXECUTION].
*/
class Tool(
class Tool
@JvmOverloads
constructor(
val functionDeclarations: List<FunctionDeclaration>? = null,
val codeExecution: JSONObject? = null,
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@

package com.google.ai.client.generative.samples

import com.google.ai.client.generativeai.GenerativeModel
import com.google.ai.client.generativeai.type.FunctionResponsePart
import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.Schema
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.content
import com.google.ai.client.generativeai.type.defineFunction
import org.json.JSONObject


// Set up your API Key
// ====================
Expand All @@ -24,4 +33,53 @@ package com.google.ai.client.generative.samples
// the "Set up your API Key section" in the [Gemini API
// quickstart](https://ai.google.dev/gemini-api/docs/quickstart?lang=android#set-up-api-key).

// TODO
suspend fun functionCalling() {
// [START function_calling]
fun multiply(a: Double, b: Double) = a * b

val multiplyDefinition = defineFunction(
name = "multiply",
description = "returns the product of the provided numbers.",
parameters = listOf(
Schema.double("a", "First number"),
Schema.double("b", "Second number")
)
)

val usableFunctions = listOf(multiplyDefinition)

val generativeModel =
GenerativeModel(
// Specify a Gemini model appropriate for your use case
modelName = "gemini-1.5-flash",
// Access your API key as a Build Configuration variable (see "Set up your API key" above)
apiKey = BuildConfig.apiKey,
// List the functions definitions you want to make available to the model
tools = listOf(Tool(usableFunctions))
)

val chat = generativeModel.startChat()
val prompt = "I have 57 cats, each owns 44 mittens, how many mittens is that in total?"

// Send the message to the generative model
var response = chat.sendMessage(prompt)

// Check if the model responded with a function call
response.functionCalls.first { it.name == "multiply" }.apply {
rlazo marked this conversation as resolved.
Show resolved Hide resolved
val a: String by args
val b: String by args

val result = JSONObject(mapOf("result" to multiply(a.toDouble(), b.toDouble())))
response = chat.sendMessage(
content(role = "function") {
part(FunctionResponsePart("multiply", result))
}
)
}

// Whenever the model responds with text, show it in the UI
response.text?.let { modelResponse ->
println(modelResponse)
}
// [END function_calling]
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@

package com.google.ai.client.generative.samples.java;

import static com.google.ai.client.generativeai.type.FunctionDeclarationsKt.defineFunction;

import com.google.ai.client.generativeai.GenerativeModel;
import com.google.ai.client.generativeai.java.ChatFutures;
import com.google.ai.client.generativeai.java.GenerativeModelFutures;
import com.google.ai.client.generativeai.type.Content;
import com.google.ai.client.generativeai.type.FunctionCallPart;
import com.google.ai.client.generativeai.type.FunctionDeclaration;
import com.google.ai.client.generativeai.type.FunctionResponsePart;
import com.google.ai.client.generativeai.type.GenerateContentResponse;
import com.google.ai.client.generativeai.type.RequestOptions;
import com.google.ai.client.generativeai.type.Schema;
import com.google.ai.client.generativeai.type.Tool;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import java.util.Arrays;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.json.JSONException;
import org.json.JSONObject;

// Set up your API Key
// ====================
//
Expand All @@ -22,5 +44,100 @@
// quickstart](https://ai.google.dev/gemini-api/docs/quickstart?lang=android#set-up-api-key).

class FunctionCalling {
// TODO

double multiply(double a, double b) {
return a * b;
}

void functionCalling() {
// [START function_calling]
FunctionDeclaration multiplyDefinition =
defineFunction(
/* name */ "multiply",
/* description */ "returns a * b.",
/* parameters */ Arrays.asList(
Schema.numDouble("a", "First parameter"),
Schema.numDouble("b", "Second parameter")),
/* required */ Arrays.asList("a", "b"));

Tool tool = new Tool(Arrays.asList(multiplyDefinition), null);

// Specify a Gemini model appropriate for your use case
GenerativeModel gm =
new GenerativeModel(
/* modelName */ "gemini-1.5-flash",
// Access your API key as a Build Configuration variable (see "Set up your API key"
// above)
/* apiKey */ BuildConfig.apiKey,
/* generationConfig (optional) */ null,
/* safetySettings (optional) */ null,
/* requestOptions (optional) */ new RequestOptions(),
/* functionDeclarations (optional) */ Arrays.asList(tool));
GenerativeModelFutures model = GenerativeModelFutures.from(gm);

// Create prompt
Content.Builder userContentBuilder = new Content.Builder();
userContentBuilder.setRole("user");
userContentBuilder.addText(
"I have 57 cats, each owns 44 mittens, how many mittens is that in total?");
Content userMessage = userContentBuilder.build();

// For illustrative purposes only. You should use an executor that fits your needs.
Executor executor = Executors.newSingleThreadExecutor();

// Initialize the chat
ChatFutures chat = model.startChat();

// Send the message
ListenableFuture<GenerateContentResponse> response = chat.sendMessage(userMessage);

Futures.addCallback(
response,
new FutureCallback<GenerateContentResponse>() {
@Override
public void onSuccess(GenerateContentResponse result) {
if (!result.getFunctionCalls().isEmpty()) {
handleFunctionCall(result);
}
if (!result.getText().isEmpty()) {
System.out.println(result.getText());
}
}

@Override
public void onFailure(Throwable t) {
t.printStackTrace();
}

private void handleFunctionCall(GenerateContentResponse result) {
FunctionCallPart multiplyFunctionCallPart =
result.getFunctionCalls().stream()
.filter(fun -> fun.getName().equals("multiply"))
.findFirst()
.get();
double a = Double.parseDouble(multiplyFunctionCallPart.getArgs().get("a"));
double b = Double.parseDouble(multiplyFunctionCallPart.getArgs().get("b"));

try {
// `multiply(a, b)` is a regular java function defined in another class
FunctionResponsePart functionResponsePart =
new FunctionResponsePart(
"multiply", new JSONObject().put("result", multiply(a, b)));

// Create prompt
Content.Builder functionCallResponse = new Content.Builder();
userContentBuilder.setRole("user");
userContentBuilder.addPart(functionResponsePart);
Content userMessage = userContentBuilder.build();

chat.sendMessage(userMessage);
} catch (JSONException e) {
throw new RuntimeException(e);
}
}
},
executor);

// [END function_calling]
}
}
Loading