Skip to content

Commit

Permalink
Add function calling snippets (google-gemini#200)
Browse files Browse the repository at this point in the history
Co-authored-by: Daymon <[email protected]>
  • Loading branch information
2 people authored and PatilShreyas committed Sep 21, 2024
1 parent 7d3ce7d commit b9549a7
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 3 deletions.
1 change: 1 addition & 0 deletions .changes/generativeai/bead-bead-basketball-boat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Improve java support for function calling"]}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ package dev.shreyaspatil.ai.client.generativeai.type
* @param functionDeclarations The set of functions that this tool allows the model access to
* @param codeExecution This is a flag value to enable Code Execution. Use [CODE_EXECUTION].
*/
class Tool(
class Tool
@JvmOverloads
constructor(
val functionDeclarations: List<FunctionDeclaration>? = null,
val codeExecution: JSONObject? = null,
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class Schema<T>(

companion object {
/** Registers a schema for a 32 bit integer number */
@JvmStatic
@JvmName("numInt")
fun int(name: String, description: String) =
Schema<Int>(
name = name,
Expand All @@ -89,6 +91,8 @@ class Schema<T>(
)

/** Registers a schema for a 64 bit integer number */
@JvmStatic
@JvmName("numLong")
fun long(name: String, description: String) =
Schema<Long>(
name = name,
Expand All @@ -98,6 +102,7 @@ class Schema<T>(
)

/** Registers a schema for a string */
@JvmStatic
fun str(name: String, description: String) =
Schema<String>(
name = name,
Expand All @@ -107,6 +112,7 @@ class Schema<T>(
)

/** Registers a schema for a boolean */
@JvmStatic
fun bool(name: String, description: String) =
Schema<Boolean>(
name = name,
Expand All @@ -116,6 +122,8 @@ class Schema<T>(
)

/** Registers a schema for a floating point number */
@JvmStatic
@JvmName("numDouble")
fun double(name: String, description: String) =
Schema<Double>(
name = name,
Expand All @@ -127,6 +135,7 @@ class Schema<T>(
/**
* Registers a schema for a complex object. In a function it will be returned as a [JSONObject]
*/
@JvmStatic
fun obj(name: String, description: String, vararg contents: Schema<out Any>) =
Schema<JSONObject>(
name = name,
Expand All @@ -142,6 +151,7 @@ class Schema<T>(
*
* @param items can be used to specify the type of the array
*/
@JvmStatic
fun arr(name: String, description: String, items: Schema<out Any>? = null) =
Schema<List<String>>(
name = name,
Expand All @@ -152,6 +162,8 @@ class Schema<T>(
)

/** Registers a schema for an enum */
@JvmStatic
@JvmName("enumeration")
fun enum(name: String, description: String, values: List<String>) =
Schema<String>(
name = name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@

package com.google.ai.client.generative.samples

import com.google.ai.client.generativeai.GenerativeModel
import com.google.ai.client.generativeai.type.FunctionResponsePart
import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.Schema
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.content
import com.google.ai.client.generativeai.type.defineFunction
import org.json.JSONObject


// Set up your API Key
// ====================
Expand All @@ -24,4 +33,53 @@ package com.google.ai.client.generative.samples
// the "Set up your API Key section" in the [Gemini API
// quickstart](https://ai.google.dev/gemini-api/docs/quickstart?lang=android#set-up-api-key).

// TODO
suspend fun functionCalling() {
// [START function_calling]
fun multiply(a: Double, b: Double) = a * b

val multiplyDefinition = defineFunction(
name = "multiply",
description = "returns the product of the provided numbers.",
parameters = listOf(
Schema.double("a", "First number"),
Schema.double("b", "Second number")
)
)

val usableFunctions = listOf(multiplyDefinition)

val generativeModel =
GenerativeModel(
// Specify a Gemini model appropriate for your use case
modelName = "gemini-1.5-flash",
// Access your API key as a Build Configuration variable (see "Set up your API key" above)
apiKey = BuildConfig.apiKey,
// List the functions definitions you want to make available to the model
tools = listOf(Tool(usableFunctions))
)

val chat = generativeModel.startChat()
val prompt = "I have 57 cats, each owns 44 mittens, how many mittens is that in total?"

// Send the message to the generative model
var response = chat.sendMessage(prompt)

// Check if the model responded with a function call
response.functionCalls.first { it.name == "multiply" }.apply {
val a: String by args
val b: String by args

val result = JSONObject(mapOf("result" to multiply(a.toDouble(), b.toDouble())))
response = chat.sendMessage(
content(role = "function") {
part(FunctionResponsePart("multiply", result))
}
)
}

// Whenever the model responds with text, show it in the UI
response.text?.let { modelResponse ->
println(modelResponse)
}
// [END function_calling]
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@

package com.google.ai.client.generative.samples.java;

import static com.google.ai.client.generativeai.type.FunctionDeclarationsKt.defineFunction;

import com.google.ai.client.generativeai.GenerativeModel;
import com.google.ai.client.generativeai.java.ChatFutures;
import com.google.ai.client.generativeai.java.GenerativeModelFutures;
import com.google.ai.client.generativeai.type.Content;
import com.google.ai.client.generativeai.type.FunctionCallPart;
import com.google.ai.client.generativeai.type.FunctionDeclaration;
import com.google.ai.client.generativeai.type.FunctionResponsePart;
import com.google.ai.client.generativeai.type.GenerateContentResponse;
import com.google.ai.client.generativeai.type.RequestOptions;
import com.google.ai.client.generativeai.type.Schema;
import com.google.ai.client.generativeai.type.Tool;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import java.util.Arrays;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.json.JSONException;
import org.json.JSONObject;

// Set up your API Key
// ====================
//
Expand All @@ -22,5 +44,100 @@
// quickstart](https://ai.google.dev/gemini-api/docs/quickstart?lang=android#set-up-api-key).

class FunctionCalling {
// TODO

double multiply(double a, double b) {
return a * b;
}

void functionCalling() {
// [START function_calling]
FunctionDeclaration multiplyDefinition =
defineFunction(
/* name */ "multiply",
/* description */ "returns a * b.",
/* parameters */ Arrays.asList(
Schema.numDouble("a", "First parameter"),
Schema.numDouble("b", "Second parameter")),
/* required */ Arrays.asList("a", "b"));

Tool tool = new Tool(Arrays.asList(multiplyDefinition), null);

// Specify a Gemini model appropriate for your use case
GenerativeModel gm =
new GenerativeModel(
/* modelName */ "gemini-1.5-flash",
// Access your API key as a Build Configuration variable (see "Set up your API key"
// above)
/* apiKey */ BuildConfig.apiKey,
/* generationConfig (optional) */ null,
/* safetySettings (optional) */ null,
/* requestOptions (optional) */ new RequestOptions(),
/* functionDeclarations (optional) */ Arrays.asList(tool));
GenerativeModelFutures model = GenerativeModelFutures.from(gm);

// Create prompt
Content.Builder userContentBuilder = new Content.Builder();
userContentBuilder.setRole("user");
userContentBuilder.addText(
"I have 57 cats, each owns 44 mittens, how many mittens is that in total?");
Content userMessage = userContentBuilder.build();

// For illustrative purposes only. You should use an executor that fits your needs.
Executor executor = Executors.newSingleThreadExecutor();

// Initialize the chat
ChatFutures chat = model.startChat();

// Send the message
ListenableFuture<GenerateContentResponse> response = chat.sendMessage(userMessage);

Futures.addCallback(
response,
new FutureCallback<GenerateContentResponse>() {
@Override
public void onSuccess(GenerateContentResponse result) {
if (!result.getFunctionCalls().isEmpty()) {
handleFunctionCall(result);
}
if (!result.getText().isEmpty()) {
System.out.println(result.getText());
}
}

@Override
public void onFailure(Throwable t) {
t.printStackTrace();
}

private void handleFunctionCall(GenerateContentResponse result) {
FunctionCallPart multiplyFunctionCallPart =
result.getFunctionCalls().stream()
.filter(fun -> fun.getName().equals("multiply"))
.findFirst()
.get();
double a = Double.parseDouble(multiplyFunctionCallPart.getArgs().get("a"));
double b = Double.parseDouble(multiplyFunctionCallPart.getArgs().get("b"));

try {
// `multiply(a, b)` is a regular java function defined in another class
FunctionResponsePart functionResponsePart =
new FunctionResponsePart(
"multiply", new JSONObject().put("result", multiply(a, b)));

// Create prompt
Content.Builder functionCallResponse = new Content.Builder();
userContentBuilder.setRole("user");
userContentBuilder.addPart(functionResponsePart);
Content userMessage = userContentBuilder.build();

chat.sendMessage(userMessage);
} catch (JSONException e) {
throw new RuntimeException(e);
}
}
},
executor);

// [END function_calling]
}
}

0 comments on commit b9549a7

Please sign in to comment.