From a9ebd46b1d8f9f6f69604174d74c1bc4c0eac070 Mon Sep 17 00:00:00 2001 From: Nate Bosch Date: Fri, 10 May 2024 15:15:15 -0700 Subject: [PATCH] Add Content.functionResponses utility (#159) The alternative when replying to multiple functions in parallel is to use the `Content` constructor and pass the string `'function'` to the `role` parameter. Use a static method to obscure the role string for consistency with the others. Update sample with a demonstration of replying to all function calls. --- pkgs/google_generative_ai/CHANGELOG.md | 2 + .../google_generative_ai/lib/src/content.dart | 2 + samples/dart/bin/function_calling.dart | 60 ++++++++++--------- 3 files changed, 37 insertions(+), 27 deletions(-) diff --git a/pkgs/google_generative_ai/CHANGELOG.md b/pkgs/google_generative_ai/CHANGELOG.md index 46dcf67..f839bb4 100644 --- a/pkgs/google_generative_ai/CHANGELOG.md +++ b/pkgs/google_generative_ai/CHANGELOG.md @@ -6,6 +6,8 @@ `'application/json'` to force the model to reply with JSON parseable output. - Add `outputDimensionality` argument support for `embedContent` and `batchEmbedContent`. +- Add `Content.functionResponses` utility to reply to multiple function calls in + parallel. - **Breaking** The `Part` class is no longer `sealed`. Exhaustive switches over a `Part` instance will need to add a wildcard case. diff --git a/pkgs/google_generative_ai/lib/src/content.dart b/pkgs/google_generative_ai/lib/src/content.dart index f63af6b..5f462ac 100644 --- a/pkgs/google_generative_ai/lib/src/content.dart +++ b/pkgs/google_generative_ai/lib/src/content.dart @@ -38,6 +38,8 @@ final class Content { static Content functionResponse( String name, Map? response) => Content('function', [FunctionResponse(name, response)]); + static Content functionResponses(Iterable responses) => + Content('function', responses.toList()); static Content system(String instructions) => Content('system', [TextPart(instructions)]); diff --git a/samples/dart/bin/function_calling.dart b/samples/dart/bin/function_calling.dart index e34e736..689ccd3 100644 --- a/samples/dart/bin/function_calling.dart +++ b/samples/dart/bin/function_calling.dart @@ -23,7 +23,7 @@ void main() async { exit(1); } final model = GenerativeModel( - model: 'gemini-pro', + model: 'gemini-1.5-pro-latest', apiKey: apiKey, tools: [ Tool(functionDeclarations: [ @@ -31,9 +31,8 @@ void main() async { 'fetchCurrentWeather', 'Returns the weather in a given location.', Schema(SchemaType.object, properties: { - 'location': Schema(SchemaType.string), - 'unit': Schema(SchemaType.string, - enumValues: ['celcius', 'farenheit']) + 'location': + Schema.string(description: 'A location name, like "London".'), }, requiredProperties: [ 'location' ])) @@ -41,34 +40,36 @@ void main() async { ], ); - final prompt = 'What is the weather in Seattle?'; + final prompt = + "I'm trying to decide whether to go to London or Zurich this weekend. " + 'How hot are those cities? How about Singapore? Or maybe Tokyo. ' + 'I want to go somewhere not that cold but not too hot either. ' + 'Suggest a destination.'; final content = [Content.text(prompt)]; - final response = await model.generateContent(content); + var response = await model.generateContent(content); - final functionCalls = response.functionCalls.toList(); - if (functionCalls.isEmpty) { - print('No function calls.'); - print(response.text); - } else if (functionCalls.length > 1) { - print('Too many function calls.'); - print(response.text); - } else { + List functionCalls; + while ((functionCalls = response.functionCalls.toList()).isNotEmpty) { + var responses = [ + for (final functionCall in functionCalls) + _dispatchFunctionCall(functionCall) + ]; content ..add(response.candidates.first.content) - ..add(_dispatchFunctionCall(functionCalls.single)); - final nextResponse = await model.generateContent(content); - print('Response: ${nextResponse.text}'); + ..add(Content.functionResponses(responses)); + response = await model.generateContent(content); } + print('Response: ${response.text}'); } -Content _dispatchFunctionCall(FunctionCall call) { +FunctionResponse _dispatchFunctionCall(FunctionCall call) { final result = switch (call.name) { 'fetchCurrentWeather' => { 'weather': _fetchWeather(WeatherRequest._parse(call.args)) }, _ => throw UnimplementedError('Function not implemented: ${call.name}') }; - return Content.functionResponse(call.name, result); + return FunctionResponse(call.name, result); } class WeatherRequest { @@ -80,14 +81,19 @@ class WeatherRequest { }; final String location; WeatherRequest(this.location); + + @override + String toString() => {'location': location}.toString(); } -String _fetchWeather(WeatherRequest request) { - const weather = { - 'Seattle': 'rainy', - 'Chicago': 'windy', - 'Sunnyvale': 'sunny' - }; - final location = request.location; - return weather[location] ?? 'who knows?'; +var _responseIndex = -1; +Map _fetchWeather(WeatherRequest request) { + const responses = >[ + {'condition': 'sunny', 'temp_c': -23.9}, + {'condition': 'extreme rainstorm', 'temp_c': 13.9}, + {'condition': 'cloudy', 'temp_c': 33.9}, + {'condition': 'moderate', 'temp_c': 19.9}, + ]; + _responseIndex = (_responseIndex + 1) % responses.length; + return responses[_responseIndex]; }