Skip to content

Commit

Permalink
Include full GenerateContentRequest for countTokens (#169)
Browse files Browse the repository at this point in the history
This aligns the counted tokens with what the model would count
if the same request were made to `generateContent`.

Extract the JSON conversion to a separate method and call it from
`generateContent`, `generateContentStream`, and `countTokens`.

Add `model` to the general `GenerateContentRequest`. It is required
when using that pattern for counting tokens, and harmless otherwise.
  • Loading branch information
natebosch authored May 22, 2024
1 parent 42a5866 commit 1f2b428
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 49 deletions.
5 changes: 5 additions & 0 deletions pkgs/google_generative_ai/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
replies with more than one part.
- Fix handling of `format` argument to `Schema.number` and `Schema.integer`.
- Export `UsageMetadata`.
- Include the full `GenerateContentRequest` (previously omitted
  `safetySettings`, `generationConfig`, `tools`, `toolConfig`, and
  `systemInstruction`) in `countTokens` requests. This aligns the token count
  with what the backend will see in practice for an equivalent
  `generateContent` request.

## 0.4.0

Expand Down
104 changes: 64 additions & 40 deletions pkgs/google_generative_ai/lib/src/model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -168,23 +168,15 @@ final class GenerativeModel {
List<Tool>? tools,
ToolConfig? toolConfig,
}) async {
safetySettings ??= _safetySettings;
generationConfig ??= _generationConfig;
tools ??= _tools;
toolConfig ??= _toolConfig;
final parameters = {
'contents': prompt.map((p) => p.toJson()).toList(),
if (safetySettings.isNotEmpty)
'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
if (generationConfig != null)
'generationConfig': generationConfig.toJson(),
if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
if (_systemInstruction case final systemInstruction?)
'systemInstruction': systemInstruction.toJson(),
};
final response =
await _client.makeRequest(_taskUri(Task.generateContent), parameters);
final response = await _client.makeRequest(
_taskUri(Task.generateContent),
_generateContentRequest(
prompt,
safetySettings: safetySettings,
generationConfig: generationConfig,
tools: tools,
toolConfig: toolConfig,
));
return parseGenerateContentResponse(response);
}

Expand Down Expand Up @@ -212,23 +204,15 @@ final class GenerativeModel {
List<Tool>? tools,
ToolConfig? toolConfig,
}) {
safetySettings ??= _safetySettings;
generationConfig ??= _generationConfig;
tools ??= _tools;
toolConfig ??= _toolConfig;
final parameters = <String, Object?>{
'contents': prompt.map((p) => p.toJson()).toList(),
if (safetySettings.isNotEmpty)
'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
if (generationConfig != null)
'generationConfig': generationConfig.toJson(),
if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
if (_systemInstruction case final systemInstruction?)
'systemInstruction': systemInstruction.toJson(),
};
final response =
_client.streamRequest(_taskUri(Task.streamGenerateContent), parameters);
final response = _client.streamRequest(
_taskUri(Task.streamGenerateContent),
_generateContentRequest(
prompt,
safetySettings: safetySettings,
generationConfig: generationConfig,
tools: tools,
toolConfig: toolConfig,
));
return response.map(parseGenerateContentResponse);
}

Expand All @@ -237,6 +221,11 @@ final class GenerativeModel {
/// Sends a "countTokens" API request for the configured model,
/// and waits for the response.
///
/// The [safetySettings], [generationConfig], [tools], and [toolConfig]
/// override the arguments of the same name passed to the
/// [GenerativeModel.new] constructor. Each argument, when non-null,
/// overrides the model-level configuration in its entirety.
///
/// Example:
/// ```dart
/// final promptContent = [Content.text(prompt)];
Expand All @@ -249,12 +238,22 @@ final class GenerativeModel {
/// print(response.text);
/// }
/// ```
Future<CountTokensResponse> countTokens(Iterable<Content> contents) async {
final parameters = <String, Object?>{
'contents': contents.map((c) => c.toJson()).toList()
};
final response =
await _client.makeRequest(_taskUri(Task.countTokens), parameters);
Future<CountTokensResponse> countTokens(
  Iterable<Content> contents, {
  List<SafetySetting>? safetySettings,
  GenerationConfig? generationConfig,
  List<Tool>? tools,
  ToolConfig? toolConfig,
}) async {
  // Embed the full generate-content payload under `generateContentRequest`
  // so the backend counts tokens against the same request shape that
  // `generateContent` would send (safety settings, tools, etc. included).
  final request = <String, Object?>{
    'generateContentRequest': _generateContentRequest(
      contents,
      safetySettings: safetySettings,
      generationConfig: generationConfig,
      tools: tools,
      toolConfig: toolConfig,
    ),
  };
  final json = await _client.makeRequest(_taskUri(Task.countTokens), request);
  return parseCountTokensResponse(json);
}

Expand Down Expand Up @@ -307,6 +306,31 @@ final class GenerativeModel {
_taskUri(Task.batchEmbedContents), parameters);
return parseBatchEmbedContentsResponse(response);
}

/// Builds the JSON payload for a `GenerateContentRequest`.
///
/// Shared by `generateContent`, `generateContentStream`, and `countTokens`
/// so all three send an identical request shape. Per-call arguments take
/// precedence; null arguments fall back to the model-level configuration
/// captured at construction time.
Map<String, Object?> _generateContentRequest(
  Iterable<Content> contents, {
  List<SafetySetting>? safetySettings,
  GenerationConfig? generationConfig,
  List<Tool>? tools,
  ToolConfig? toolConfig,
}) {
  final effectiveSafetySettings = safetySettings ?? _safetySettings;
  final effectiveGenerationConfig = generationConfig ?? _generationConfig;
  final effectiveTools = tools ?? _tools;
  final effectiveToolConfig = toolConfig ?? _toolConfig;
  final json = <String, Object?>{
    // `model` is required when this payload is nested inside a
    // `countTokens` request, and harmless for `generateContent`.
    'model': '${_model.prefix}/${_model.name}',
    'contents': [for (final content in contents) content.toJson()],
  };
  // Insertion order below matches the original literal: safetySettings,
  // generationConfig, tools, toolConfig, systemInstruction.
  if (effectiveSafetySettings.isNotEmpty) {
    json['safetySettings'] = [
      for (final setting in effectiveSafetySettings) setting.toJson()
    ];
  }
  if (effectiveGenerationConfig != null) {
    json['generationConfig'] = effectiveGenerationConfig.toJson();
  }
  if (effectiveTools != null) {
    json['tools'] = [for (final tool in effectiveTools) tool.toJson()];
  }
  if (effectiveToolConfig != null) {
    json['toolConfig'] = effectiveToolConfig.toJson();
  }
  if (_systemInstruction case final systemInstruction?) {
    json['systemInstruction'] = systemInstruction.toJson();
  }
  return json;
}
}

/// Creates a model with an overridden [ApiClient] for testing.
Expand Down
87 changes: 79 additions & 8 deletions pkgs/google_generative_ai/test/generative_model_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ void main() {
),
);
expect(request, {
'model': 'models/$defaultModelName',
'contents': [
{
'role': 'user',
Expand Down Expand Up @@ -311,6 +312,7 @@ void main() {
),
);
expect(request, {
'model': 'models/$defaultModelName',
'contents': [
{
'role': 'user',
Expand Down Expand Up @@ -415,20 +417,89 @@ void main() {
),
);
expect(request, {
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt},
],
},
],
'generateContentRequest': {
'model': 'models/$defaultModelName',
'contents': [
{
'role': 'user',
'parts': [
{'text': prompt},
],
},
],
}
});
},
response: {'totalTokens': 2},
);
expect(response, matchesCountTokensResponse(CountTokensResponse(2)));
});

// Verifies that per-call arguments to `countTokens` (safetySettings,
// generationConfig, tools, toolConfig) are forwarded into the nested
// `generateContentRequest` payload rather than being dropped.
test('can override GenerateContentRequest fields', () async {
  final (client, model) = createModel();
  final prompt = 'Some prompt';
  await client.checkRequest(
    response: {'totalTokens': 100},
    () => model.countTokens(
      [Content.text(prompt)],
      safetySettings: [
        SafetySetting(
          HarmCategory.dangerousContent,
          HarmBlockThreshold.high,
        ),
      ],
      generationConfig: GenerationConfig(stopSequences: ['a']),
      tools: [
        Tool(functionDeclarations: [
          FunctionDeclaration(
            'someFunction',
            'Some cool function.',
            Schema(SchemaType.string, description: 'Some parameter.'),
          ),
        ]),
      ],
      toolConfig: ToolConfig(
        functionCallingConfig: FunctionCallingConfig(
          mode: FunctionCallingMode.any,
          allowedFunctionNames: {'someFunction'},
        ),
      ),
    ),
    verifyRequest: (_, countTokensRequest) {
      // The payload under test is nested under `generateContentRequest`.
      final request = countTokensRequest['generateContentRequest']
          as Map<String, Object?>;
      // Each override should appear in its serialized (JSON wire) form.
      expect(request['safetySettings'], [
        {
          'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
          'threshold': 'BLOCK_ONLY_HIGH',
        },
      ]);
      expect(request['generationConfig'], {
        'stopSequences': ['a'],
      });
      expect(request['tools'], [
        {
          'functionDeclarations': [
            {
              'name': 'someFunction',
              'description': 'Some cool function.',
              'parameters': {
                'type': 'STRING',
                'description': 'Some parameter.',
              },
            },
          ],
        },
      ]);
      expect(request['toolConfig'], {
        'functionCallingConfig': {
          'mode': 'ANY',
          'allowedFunctionNames': ['someFunction'],
        },
      });
    },
  );
});
});

group('embed content', () {
Expand Down
8 changes: 7 additions & 1 deletion samples/dart/bin/simple_text.dart
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ void main() async {
stderr.writeln(r'No $GOOGLE_API_KEY environment variable');
exit(1);
}
final model = GenerativeModel(model: 'gemini-pro', apiKey: apiKey);
final model = GenerativeModel(
model: 'gemini-pro',
apiKey: apiKey,
safetySettings: [
SafetySetting(HarmCategory.dangerousContent, HarmBlockThreshold.high)
],
generationConfig: GenerationConfig(maxOutputTokens: 200));
final prompt = 'Write a story about a magic backpack.';
print('Prompt: $prompt');
final content = [Content.text(prompt)];
Expand Down

0 comments on commit 1f2b428

Please sign in to comment.