From 7afcf897889b6e3e5c6bd51ad29048dc60539643 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 8 May 2024 16:31:59 -0400 Subject: [PATCH] Make `text` computed property handle mixed-parts responses (#165) --- .../GoogleAI/GenerateContentResponse.swift | 10 ++++- ...y-success-function-call-mixed-content.json | 37 +++++++++++++++++ ...-success-function-call-parallel-calls.json | 40 +++++++++++++++++++ .../GoogleAITests/GenerativeModelTests.swift | 34 ++++++++++++++++ 4 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-mixed-content.json create mode 100644 Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-parallel-calls.json diff --git a/Sources/GoogleAI/GenerateContentResponse.swift b/Sources/GoogleAI/GenerateContentResponse.swift index 683df0c..04c41f7 100644 --- a/Sources/GoogleAI/GenerateContentResponse.swift +++ b/Sources/GoogleAI/GenerateContentResponse.swift @@ -45,11 +45,17 @@ public struct GenerateContentResponse { Logging.default.error("Could not get text from a response that had no candidates.") return nil } - guard let text = candidate.content.parts.first?.text else { + let textValues: [String] = candidate.content.parts.compactMap { part in + guard case let .text(text) = part else { + return nil + } + return text + } + guard textValues.count > 0 else { Logging.default.error("Could not get a text part from the first candidate.") return nil } - return text + return textValues.joined(separator: " ") } /// Returns function calls found in any `Part`s of the first candidate of the response, if any. diff --git a/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-mixed-content.json b/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-mixed-content.json new file mode 100644 index 0000000..6e7ce27 --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-mixed-content.json @@ -0,0 +1,37 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "The sum of [1, 2," + }, + { + "functionCall": { + "name": "sum", + "args": { + "y": 1, + "x": 2 + } + } + }, + { + "text": "3] is" + }, + { + "functionCall": { + "name": "sum", + "args": { + "y": 3, + "x": 3 + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ] +} diff --git a/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-parallel-calls.json b/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-parallel-calls.json new file mode 100644 index 0000000..d535f8e --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-parallel-calls.json @@ -0,0 +1,40 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "sum", + "args": { + "y": 1, + "x": 2 + } + } + }, + { + "functionCall": { + "name": "sum", + "args": { + "y": 3, + "x": 4 + } + } + }, + { + "functionCall": { + "name": "sum", + "args": { + "y": 5, + "x": 6 + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ] +} diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 9ed3401..ccd8979 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -254,6 +254,40 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.functionCalls, [functionCall]) } + func testGenerateContent_success_functionCall_parallelCalls() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-function-call-parallel-calls", + withExtension: "json" + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.content.parts.count, 3) + let functionCalls = response.functionCalls + XCTAssertEqual(functionCalls.count, 3) + } + + func testGenerateContent_success_functionCall_mixedContent() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-function-call-mixed-content", + withExtension: "json" + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.content.parts.count, 4) + let functionCalls = response.functionCalls + XCTAssertEqual(functionCalls.count, 2) + let text = try XCTUnwrap(response.text) + XCTAssertEqual(text, "The sum of [1, 2, 3] is") + } + func testGenerateContent_usageMetadata() async throws { MockURLProtocol .requestHandler = try httpRequestHandler(