From ab4f33c18868f999be4630f6d12bed616f860bc1 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 2 Apr 2024 22:50:29 +0000 Subject: [PATCH] Add functionCalls accessor to GenerateContentResponse (#123) --- Sources/GoogleAI/GenerateContentResponse.swift | 13 +++++++++++++ Tests/GoogleAITests/GenerativeModelTests.swift | 5 +++++ 2 files changed, 18 insertions(+) diff --git a/Sources/GoogleAI/GenerateContentResponse.swift b/Sources/GoogleAI/GenerateContentResponse.swift index 03153da..5cfa5ae 100644 --- a/Sources/GoogleAI/GenerateContentResponse.swift +++ b/Sources/GoogleAI/GenerateContentResponse.swift @@ -37,6 +37,19 @@ public struct GenerateContentResponse { return text } + /// Returns function calls found in any `Part`s of the first candidate of the response, if any. + public var functionCalls: [FunctionCall] { + guard let candidate = candidates.first else { + return [] + } + return candidate.content.parts.compactMap { part in + guard case let .functionCall(functionCall) = part else { + return nil + } + return functionCall + } + } + /// Initializer for SwiftUI previews or tests. public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback?) { self.candidates = candidates diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 5cd14d9..2b2f8ee 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -63,6 +63,7 @@ final class GenerativeModelTests: XCTestCase { let promptFeedback = try XCTUnwrap(response.promptFeedback) XCTAssertNil(promptFeedback.blockReason) XCTAssertEqual(promptFeedback.safetyRatings, safetyRatingsNegligible) + XCTAssertEqual(response.functionCalls, []) } func testGenerateContent_success_basicReplyShort() async throws { @@ -86,6 +87,7 @@ final class GenerativeModelTests: XCTestCase { let promptFeedback = try XCTUnwrap(response.promptFeedback) XCTAssertNil(promptFeedback.blockReason) XCTAssertEqual(promptFeedback.safetyRatings, safetyRatingsNegligible) + XCTAssertEqual(response.functionCalls, []) } func testGenerateContent_success_citations() async throws { @@ -188,6 +190,7 @@ final class GenerativeModelTests: XCTestCase { } XCTAssertEqual(functionCall.name, "current_time") XCTAssertTrue(functionCall.args.isEmpty) + XCTAssertEqual(response.functionCalls, [functionCall]) } func testGenerateContent_success_functionCall_noArguments() async throws { @@ -209,6 +212,7 @@ final class GenerativeModelTests: XCTestCase { } XCTAssertEqual(functionCall.name, "current_time") XCTAssertTrue(functionCall.args.isEmpty) + XCTAssertEqual(response.functionCalls, [functionCall]) } func testGenerateContent_success_functionCall_withArguments() async throws { @@ -234,6 +238,7 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(argX, .number(4)) let argY = try XCTUnwrap(functionCall.args["y"]) XCTAssertEqual(argY, .number(5)) + XCTAssertEqual(response.functionCalls, [functionCall]) } func testGenerateContent_failure_invalidAPIKey() async throws {