Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Controlled generation sample #205

Merged
merged 5 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .changes/generativeai/club-bite-carpenter-country.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Improve usability of the Schema type in Java"]}
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ import org.json.JSONObject
*/
class FunctionType<T>(val name: String, val parse: (String?) -> T?) {
companion object {
val STRING = FunctionType<String>("STRING") { it }
val INTEGER = FunctionType<Int>("INTEGER") { it?.toIntOrNull() }
val LONG = FunctionType<Long>("INTEGER") { it?.toLongOrNull() }
val NUMBER = FunctionType<Double>("NUMBER") { it?.toDoubleOrNull() }
val BOOLEAN = FunctionType<Boolean>("BOOLEAN") { it?.toBoolean() }
@JvmField val STRING = FunctionType<String>("STRING") { it }
@JvmField val INTEGER = FunctionType<Int>("INTEGER") { it?.toIntOrNull() }
@JvmField val LONG = FunctionType<Long>("INTEGER") { it?.toLongOrNull() }
@JvmField val NUMBER = FunctionType<Double>("NUMBER") { it?.toDoubleOrNull() }
@JvmField val BOOLEAN = FunctionType<Boolean>("BOOLEAN") { it?.toBoolean() }
@JvmField
val ARRAY =
FunctionType<List<String>>("ARRAY") { it ->
it?.let { Json.parseToJsonElement(it).jsonArray.map { element -> element.toString() } }
}
val OBJECT = FunctionType<JSONObject>("OBJECT") { it?.let { JSONObject(it) } }
@JvmField val OBJECT = FunctionType<JSONObject>("OBJECT") { it?.let { JSONObject(it) } }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,73 @@

package com.google.ai.client.generative.samples

import com.google.ai.client.generativeai.GenerativeModel
import com.google.ai.client.generativeai.type.FunctionType
import com.google.ai.client.generativeai.type.Schema
import com.google.ai.client.generativeai.type.generationConfig

// Set up your API Key
// ====================
//
// To use the Gemini API, you'll need an API key. To learn more, see
// the "Set up your API Key section" in the [Gemini API
// quickstart](https://ai.google.dev/gemini-api/docs/quickstart?lang=android#set-up-api-key).

// TODO
suspend fun json_controlled_generation() {
// [START json_controlled_generation]
val generativeModel =
GenerativeModel(
// Specify a Gemini model appropriate for your use case
modelName = "gemini-1.5-pro",
// Access your API key as a Build Configuration variable (see "Set up your API key" above)
apiKey = BuildConfig.apiKey,
generationConfig = generationConfig {
responseMimeType = "application/json"
responseSchema = Schema(
name = "recipes",
description = "List of recipes",
type = FunctionType.ARRAY,
items = Schema(
name = "recipe",
description = "A recipe",
type = FunctionType.OBJECT,
properties = mapOf(
"recipeName" to Schema(
name = "recipeName",
description = "Name of the recipe",
type = FunctionType.STRING,
nullable = false
),
),
required = listOf("recipeName")
),
)
})

val prompt = "List a few popular cookie recipes."
val response = generativeModel.generateContent(prompt)
print(response.text)
// [END json_controlled_generation]
}

suspend fun json_no_schema() {
// [START json_no_schema]
val generativeModel =
GenerativeModel(
// Specify a Gemini model appropriate for your use case
modelName = "gemini-1.5-flash",
// Access your API key as a Build Configuration variable (see "Set up your API key" above)
apiKey = BuildConfig.apiKey,
generationConfig = generationConfig {
responseMimeType = "application/json"
})

val prompt = """
List a few popular cookie recipes using this JSON schema:
Recipe = {'recipeName': string}
Return: Array<Recipe>
""".trimIndent()
val response = generativeModel.generateContent(prompt)
print(response.text)
// [END json_no_schema]
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void codeExecutionBasic() {
/* generationConfig */ null,
/* safetySettings */ null,
/* requestOptions */ new RequestOptions(),
/* tools */ Collections.singletonList(Tool.Companion.getCODE_EXECUTION()));
/* tools */ Collections.singletonList(Tool.CODE_EXECUTION));
GenerativeModelFutures model = GenerativeModelFutures.from(gm);

Content inputContent =
Expand Down Expand Up @@ -91,7 +91,7 @@ void codeExecutionChat() {
/* generationConfig */ null,
/* safetySettings */ null,
/* requestOptions */ new RequestOptions(),
/* tools */ Collections.singletonList(Tool.Companion.getCODE_EXECUTION()));
/* tools */ Collections.singletonList(Tool.CODE_EXECUTION));
GenerativeModelFutures model = GenerativeModelFutures.from(gm);

Content inputContent =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,140 @@
// the "Set up your API Key section" in the [Gemini API
// quickstart](https://ai.google.dev/gemini-api/docs/quickstart?lang=android#set-up-api-key).

import com.google.ai.client.generativeai.GenerativeModel;
import com.google.ai.client.generativeai.java.GenerativeModelFutures;
import com.google.ai.client.generativeai.type.Content;
import com.google.ai.client.generativeai.type.FunctionType;
import com.google.ai.client.generativeai.type.GenerateContentResponse;
import com.google.ai.client.generativeai.type.GenerationConfig;
import com.google.ai.client.generativeai.type.Schema;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

class ControlledGeneration {
// TODO
void jsonControlledGeneration() {
// [START json_controlled_generation]
Schema<List<String>> schema =
new Schema(
/* name */ "recipes",
/* description */ "List of recipes",
/* format */ null,
/* nullable */ false,
/* list */ null,
/* properties */ null,
/* required */ null,
/* items */ new Schema(
/* name */ "recipe",
/* description */ "A recipe",
/* format */ null,
/* nullable */ false,
/* list */ null,
/* properties */ Map.of(
"recipeName",
new Schema(
/* name */ "recipeName",
/* description */ "Name of the recipe",
/* format */ null,
/* nullable */ false,
/* list */ null,
/* properties */ null,
/* required */ null,
/* items */ null,
/* type */ FunctionType.STRING)),
/* required */ null,
/* items */ null,
/* type */ FunctionType.OBJECT),
/* type */ FunctionType.ARRAY);

GenerationConfig.Builder configBuilder = new GenerationConfig.Builder();
configBuilder.responseMimeType = "application/json";
configBuilder.responseSchema = schema;

GenerationConfig generationConfig = configBuilder.build();

// Specify a Gemini model appropriate for your use case
GenerativeModel gm =
new GenerativeModel(
/* modelName */ "gemini-1.5-pro",
// Access your API key as a Build Configuration variable (see "Set up your API key"
// above)
/* apiKey */ BuildConfig.apiKey,
/* generationConfig */ generationConfig);
GenerativeModelFutures model = GenerativeModelFutures.from(gm);

Content content = new Content.Builder().addText("List a few popular cookie recipes.").build();

// For illustrative purposes only. You should use an executor that fits your needs.
Executor executor = Executors.newSingleThreadExecutor();

ListenableFuture<GenerateContentResponse> response = model.generateContent(content);
Futures.addCallback(
response,
new FutureCallback<GenerateContentResponse>() {
@Override
public void onSuccess(GenerateContentResponse result) {
String resultText = result.getText();
System.out.println(resultText);
}

@Override
public void onFailure(Throwable t) {
t.printStackTrace();
}
},
executor);
// [END json_controlled_generation]
}

void json_no_schema() {
// [START json_no_schema]
GenerationConfig.Builder configBuilder = new GenerationConfig.Builder();
configBuilder.responseMimeType = "application/json";

GenerationConfig generationConfig = configBuilder.build();

// Specify a Gemini model appropriate for your use case
GenerativeModel gm =
new GenerativeModel(
/* modelName */ "gemini-1.5-flash",
// Access your API key as a Build Configuration variable (see "Set up your API key"
// above)
/* apiKey */ BuildConfig.apiKey,
/* generationConfig */ generationConfig);
GenerativeModelFutures model = GenerativeModelFutures.from(gm);

Content content =
new Content.Builder()
.addText(
"List a few popular cookie recipes using this JSON schema:\n"
+ "Recipe = {'recipeName': string}\n"
+ "Return: Array<Recipe>")
.build();

// For illustrative purposes only. You should use an executor that fits your needs.
Executor executor = Executors.newSingleThreadExecutor();

ListenableFuture<GenerateContentResponse> response = model.generateContent(content);
Futures.addCallback(
response,
new FutureCallback<GenerateContentResponse>() {
@Override
public void onSuccess(GenerateContentResponse result) {
String resultText = result.getText();
System.out.println(resultText);
}

@Override
public void onFailure(Throwable t) {
t.printStackTrace();
}
},
executor);
// [END json_no_schema]
}
}
Loading