diff --git a/pkgs/google_generative_ai/CHANGELOG.md b/pkgs/google_generative_ai/CHANGELOG.md
index 416a7f5..f5c3e61 100644
--- a/pkgs/google_generative_ai/CHANGELOG.md
+++ b/pkgs/google_generative_ai/CHANGELOG.md
@@ -4,6 +4,11 @@
   replies with more than one part.
 - Fix handling of `format` argument to `Schema.number` and `Schema.integer`.
 - Export `UsageMetadata`.
+- Include the full `GenerateContentRequest` in `countTokens` requests;
+  previously `safetySettings`, `generationConfig`, `tools`, `toolConfig`,
+  and `systemInstruction` were omitted. This aligns the reported token
+  count with what the backend will actually see for an equivalent
+  `generateContent` request.
 
 ## 0.4.0
 
diff --git a/pkgs/google_generative_ai/lib/src/model.dart b/pkgs/google_generative_ai/lib/src/model.dart
index 79be8f1..0ec1912 100644
--- a/pkgs/google_generative_ai/lib/src/model.dart
+++ b/pkgs/google_generative_ai/lib/src/model.dart
@@ -168,23 +168,15 @@ final class GenerativeModel {
     List<Tool>? tools,
     ToolConfig? toolConfig,
   }) async {
-    safetySettings ??= _safetySettings;
-    generationConfig ??= _generationConfig;
-    tools ??= _tools;
-    toolConfig ??= _toolConfig;
-    final parameters = <String, Object?>{
-      'contents': prompt.map((p) => p.toJson()).toList(),
-      if (safetySettings.isNotEmpty)
-        'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
-      if (generationConfig != null)
-        'generationConfig': generationConfig.toJson(),
-      if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
-      if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
-      if (_systemInstruction case final systemInstruction?)
-        'systemInstruction': systemInstruction.toJson(),
-    };
-    final response =
-        await _client.makeRequest(_taskUri(Task.generateContent), parameters);
+    final response = await _client.makeRequest(
+        _taskUri(Task.generateContent),
+        _generateContentRequest(
+          prompt,
+          safetySettings: safetySettings,
+          generationConfig: generationConfig,
+          tools: tools,
+          toolConfig: toolConfig,
+        ));
     return parseGenerateContentResponse(response);
   }
 
@@ -212,23 +204,15 @@ final class GenerativeModel {
     List<Tool>? tools,
     ToolConfig? toolConfig,
   }) {
-    safetySettings ??= _safetySettings;
-    generationConfig ??= _generationConfig;
-    tools ??= _tools;
-    toolConfig ??= _toolConfig;
-    final parameters = <String, Object?>{
-      'contents': prompt.map((p) => p.toJson()).toList(),
-      if (safetySettings.isNotEmpty)
-        'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
-      if (generationConfig != null)
-        'generationConfig': generationConfig.toJson(),
-      if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
-      if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
-      if (_systemInstruction case final systemInstruction?)
-        'systemInstruction': systemInstruction.toJson(),
-    };
-    final response =
-        _client.streamRequest(_taskUri(Task.streamGenerateContent), parameters);
+    final response = _client.streamRequest(
+        _taskUri(Task.streamGenerateContent),
+        _generateContentRequest(
+          prompt,
+          safetySettings: safetySettings,
+          generationConfig: generationConfig,
+          tools: tools,
+          toolConfig: toolConfig,
+        ));
     return response.map(parseGenerateContentResponse);
   }
 
@@ -237,6 +221,11 @@ final class GenerativeModel {
   /// Sends a "countTokens" API request for the configured model,
   /// and waits for the response.
   ///
+  /// The [safetySettings], [generationConfig], [tools], and [toolConfig]
+  /// arguments override those with the same name passed to the
+  /// [GenerativeModel.new] constructor. Each argument, when non-null,
+  /// replaces the model-level configuration in its entirety.
+  ///
   /// Example:
   /// ```dart
   /// final promptContent = [Content.text(prompt)];
@@ -249,12 +238,22 @@ final class GenerativeModel {
   /// print(response.text);
   /// }
   /// ```
-  Future<CountTokensResponse> countTokens(Iterable<Content> contents) async {
-    final parameters = <String, Object?>{
-      'contents': contents.map((c) => c.toJson()).toList()
-    };
-    final response =
-        await _client.makeRequest(_taskUri(Task.countTokens), parameters);
+  Future<CountTokensResponse> countTokens(
+    Iterable<Content> contents, {
+    List<SafetySetting>? safetySettings,
+    GenerationConfig? generationConfig,
+    List<Tool>? tools,
+    ToolConfig? toolConfig,
+  }) async {
+    final response = await _client.makeRequest(_taskUri(Task.countTokens), {
+      'generateContentRequest': _generateContentRequest(
+        contents,
+        safetySettings: safetySettings,
+        generationConfig: generationConfig,
+        tools: tools,
+        toolConfig: toolConfig,
+      )
+    });
     return parseCountTokensResponse(response);
   }
 
@@ -307,6 +306,31 @@ final class GenerativeModel {
         _taskUri(Task.batchEmbedContents), parameters);
     return parseBatchEmbedContentsResponse(response);
   }
+
+  Map<String, Object?> _generateContentRequest(
+    Iterable<Content> contents, {
+    List<SafetySetting>? safetySettings,
+    GenerationConfig? generationConfig,
+    List<Tool>? tools,
+    ToolConfig? toolConfig,
+  }) {
+    safetySettings ??= _safetySettings;
+    generationConfig ??= _generationConfig;
+    tools ??= _tools;
+    toolConfig ??= _toolConfig;
+    return {
+      'model': '${_model.prefix}/${_model.name}',
+      'contents': contents.map((c) => c.toJson()).toList(),
+      if (safetySettings.isNotEmpty)
+        'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
+      if (generationConfig != null)
+        'generationConfig': generationConfig.toJson(),
+      if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
+      if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
+      if (_systemInstruction case final systemInstruction?)
+        'systemInstruction': systemInstruction.toJson(),
+    };
+  }
 }
 
 /// Creates a model with an overridden [ApiClient] for testing.
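To make the new `countTokens` surface concrete, here is a minimal sketch of a caller passing per-call overrides. The sketch is not part of the patch: it assumes a `GOOGLE_API_KEY` environment variable and reuses the `gemini-pro` model name from the samples.

```dart
import 'dart:io';

import 'package:google_generative_ai/google_generative_ai.dart';

void main() async {
  final apiKey = Platform.environment['GOOGLE_API_KEY'];
  if (apiKey == null) {
    stderr.writeln(r'No $GOOGLE_API_KEY environment variable');
    exit(1);
  }
  final model = GenerativeModel(model: 'gemini-pro', apiKey: apiKey);
  // With this patch, the overrides below become part of the counted request,
  // so the reported total matches what an equivalent generateContent call
  // would send.
  final response = await model.countTokens(
    [Content.text('Write a story about a magic backpack.')],
    generationConfig: GenerationConfig(maxOutputTokens: 200),
    safetySettings: [
      SafetySetting(HarmCategory.dangerousContent, HarmBlockThreshold.high),
    ],
  );
  print('Total tokens: ${response.totalTokens}');
}
```

Previously the same call could send only the bare `contents`, so the total could miss tokens contributed by, for example, `tools` or a `systemInstruction`.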
diff --git a/pkgs/google_generative_ai/test/generative_model_test.dart b/pkgs/google_generative_ai/test/generative_model_test.dart
index a3acc3d..3a58b3b 100644
--- a/pkgs/google_generative_ai/test/generative_model_test.dart
+++ b/pkgs/google_generative_ai/test/generative_model_test.dart
@@ -99,6 +99,7 @@ void main() {
               ),
             );
             expect(request, {
+              'model': 'models/$defaultModelName',
               'contents': [
                 {
                   'role': 'user',
@@ -311,6 +312,7 @@ void main() {
               ),
             );
             expect(request, {
+              'model': 'models/$defaultModelName',
               'contents': [
                 {
                   'role': 'user',
@@ -415,20 +417,89 @@ void main() {
               ),
             );
             expect(request, {
-              'contents': [
-                {
-                  'role': 'user',
-                  'parts': [
-                    {'text': prompt},
-                  ],
-                },
-              ],
+              'generateContentRequest': {
+                'model': 'models/$defaultModelName',
+                'contents': [
+                  {
+                    'role': 'user',
+                    'parts': [
+                      {'text': prompt},
+                    ],
+                  },
+                ],
+              }
             });
           },
           response: {'totalTokens': 2},
         );
         expect(response, matchesCountTokensResponse(CountTokensResponse(2)));
       });
+
+      test('can override GenerateContentRequest fields', () async {
+        final (client, model) = createModel();
+        final prompt = 'Some prompt';
+        await client.checkRequest(
+          response: {'totalTokens': 100},
+          () => model.countTokens(
+            [Content.text(prompt)],
+            safetySettings: [
+              SafetySetting(
+                HarmCategory.dangerousContent,
+                HarmBlockThreshold.high,
+              ),
+            ],
+            generationConfig: GenerationConfig(stopSequences: ['a']),
+            tools: [
+              Tool(functionDeclarations: [
+                FunctionDeclaration(
+                  'someFunction',
+                  'Some cool function.',
+                  Schema(SchemaType.string, description: 'Some parameter.'),
+                ),
+              ]),
+            ],
+            toolConfig: ToolConfig(
+              functionCallingConfig: FunctionCallingConfig(
+                mode: FunctionCallingMode.any,
+                allowedFunctionNames: {'someFunction'},
+              ),
+            ),
+          ),
+          verifyRequest: (_, countTokensRequest) {
+            final request = countTokensRequest['generateContentRequest']
+                as Map<String, Object?>;
+            expect(request['safetySettings'], [
+              {
+                'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',
+                'threshold': 'BLOCK_ONLY_HIGH',
+              },
+            ]);
+            expect(request['generationConfig'], {
+              'stopSequences': ['a'],
+            });
+            expect(request['tools'], [
+              {
+                'functionDeclarations': [
+                  {
+                    'name': 'someFunction',
+                    'description': 'Some cool function.',
+                    'parameters': {
+                      'type': 'STRING',
+                      'description': 'Some parameter.',
+                    },
+                  },
+                ],
+              },
+            ]);
+            expect(request['toolConfig'], {
+              'functionCallingConfig': {
+                'mode': 'ANY',
+                'allowedFunctionNames': ['someFunction'],
+              },
+            });
+          },
+        );
+      });
     });
 
     group('embed content', () {
diff --git a/samples/dart/bin/simple_text.dart b/samples/dart/bin/simple_text.dart
index 4be105f..558355a 100644
--- a/samples/dart/bin/simple_text.dart
+++ b/samples/dart/bin/simple_text.dart
@@ -22,7 +22,13 @@ void main() async {
     stderr.writeln(r'No $GOOGLE_API_KEY environment variable');
     exit(1);
   }
-  final model = GenerativeModel(model: 'gemini-pro', apiKey: apiKey);
+  final model = GenerativeModel(
+      model: 'gemini-pro',
+      apiKey: apiKey,
+      safetySettings: [
+        SafetySetting(HarmCategory.dangerousContent, HarmBlockThreshold.high)
+      ],
+      generationConfig: GenerationConfig(maxOutputTokens: 200));
   final prompt = 'Write a story about a magic backpack.';
   print('Prompt: $prompt');
   final content = [Content.text(prompt)];
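For reviewers, here is the wire-format change in one place, sketched as the same Dart map literals the tests assert against; the prompt text and model name are illustrative, not authoritative.

```dart
void main() {
  // Request body that countTokens sent before this change: bare contents.
  final before = {
    'contents': [
      {
        'role': 'user',
        'parts': [
          {'text': 'Some prompt'},
        ],
      },
    ],
  };

  // After this change, the contents travel inside a full
  // GenerateContentRequest, together with the model name and, when
  // configured, safetySettings, generationConfig, tools, toolConfig, and
  // systemInstruction.
  final after = {
    'generateContentRequest': {
      'model': 'models/gemini-pro',
      ...before,
    },
  };
  print(after);
}
```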