diff --git a/Sources/GoogleAI/GenerateContentResponse.swift b/Sources/GoogleAI/GenerateContentResponse.swift index 4b639c0..6ac4576 100644 --- a/Sources/GoogleAI/GenerateContentResponse.swift +++ b/Sources/GoogleAI/GenerateContentResponse.swift @@ -174,7 +174,7 @@ public struct CitationMetadata: Decodable { /// A struct describing a source attribution. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) -public struct Citation: Decodable { +public struct Citation { /// The inclusive beginning of a sequence in a model response that derives from a cited source. public let startIndex: Int @@ -297,3 +297,23 @@ extension PromptFeedback: Decodable { } } } + +// MARK: - Codable Conformances + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension Citation: Decodable { + enum CodingKeys: CodingKey { + case startIndex + case endIndex + case uri + case license + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + startIndex = try container.decodeIfPresent(Int.self, forKey: .startIndex) ?? 0 + endIndex = try container.decode(Int.self, forKey: .endIndex) + uri = try container.decode(String.self, forKey: .uri) + license = try container.decodeIfPresent(String.self, forKey: .license) ?? "" + } +} diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 2b2f8ee..d5efcd8 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -104,12 +104,27 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(candidate.content.parts.count, 1) XCTAssertEqual(response.text, "Some information cited from an external source") let citationMetadata = try XCTUnwrap(candidate.citationMetadata) - XCTAssertEqual(citationMetadata.citationSources.count, 1) - let citationSource = try XCTUnwrap(citationMetadata.citationSources.first) - XCTAssertEqual(citationSource.uri, "https://www.example.com/some-citation") - XCTAssertEqual(citationSource.startIndex, 574) - XCTAssertEqual(citationSource.endIndex, 705) - XCTAssertEqual(citationSource.license, "") + XCTAssertEqual(citationMetadata.citationSources.count, 4) + let citationSource1 = try XCTUnwrap(citationMetadata.citationSources[0]) + XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1") + XCTAssertEqual(citationSource1.startIndex, 0) + XCTAssertEqual(citationSource1.endIndex, 128) + XCTAssertEqual(citationSource1.license, "") + let citationSource2 = try XCTUnwrap(citationMetadata.citationSources[1]) + XCTAssertEqual(citationSource2.uri, "https://www.example.com/some-citation-2") + XCTAssertEqual(citationSource2.startIndex, 130) + XCTAssertEqual(citationSource2.endIndex, 265) + XCTAssertEqual(citationSource2.license, "") + let citationSource3 = try XCTUnwrap(citationMetadata.citationSources[2]) + XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3") + XCTAssertEqual(citationSource3.startIndex, 272) + XCTAssertEqual(citationSource3.endIndex, 431) + XCTAssertEqual(citationSource3.license, "") + let citationSource4 = try XCTUnwrap(citationMetadata.citationSources[3]) + XCTAssertEqual(citationSource4.uri, "https://www.example.com/some-citation-4") + XCTAssertEqual(citationSource4.startIndex, 444) + XCTAssertEqual(citationSource4.endIndex, 630) + XCTAssertEqual(citationSource4.license, "mit") } func testGenerateContent_success_quoteReply() async throws { @@ -724,9 +739,21 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(citations.count, 8) XCTAssertTrue(citations - .contains(where: { $0.startIndex == 574 && $0.endIndex == 705 && !$0.uri.isEmpty })) + .contains(where: { + $0.startIndex == 0 && $0.endIndex == 128 && !$0.uri.isEmpty && $0.license.isEmpty + })) XCTAssertTrue(citations - .contains(where: { $0.startIndex == 899 && $0.endIndex == 1026 && !$0.uri.isEmpty })) + .contains(where: { + $0.startIndex == 130 && $0.endIndex == 265 && !$0.uri.isEmpty && $0.license.isEmpty + })) + XCTAssertTrue(citations + .contains(where: { + $0.startIndex == 272 && $0.endIndex == 431 && !$0.uri.isEmpty && $0.license.isEmpty + })) + XCTAssertTrue(citations + .contains(where: { + $0.startIndex == 444 && $0.endIndex == 630 && !$0.uri.isEmpty && $0.license == "mit" + })) } func testGenerateContentStream_errorMidStream() async throws {