Skip to content

Commit

Permalink
Update Citation decoding to handle optional values (google-gemini#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored and G.Dev.Ssomsak committed Jun 21, 2024
1 parent c70f616 commit 9ac3b2e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 9 deletions.
22 changes: 21 additions & 1 deletion Sources/GoogleAI/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ public struct CitationMetadata: Decodable {

/// A struct describing a source attribution.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct Citation: Decodable {
public struct Citation {
/// The inclusive beginning of a sequence in a model response that derives from a cited source.
public let startIndex: Int

Expand Down Expand Up @@ -297,3 +297,23 @@ extension PromptFeedback: Decodable {
}
}
}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension Citation: Decodable {
enum CodingKeys: CodingKey {
case startIndex
case endIndex
case uri
case license
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
startIndex = try container.decodeIfPresent(Int.self, forKey: .startIndex) ?? 0
endIndex = try container.decode(Int.self, forKey: .endIndex)
uri = try container.decode(String.self, forKey: .uri)
license = try container.decodeIfPresent(String.self, forKey: .license) ?? ""
}
}
43 changes: 35 additions & 8 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,27 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(candidate.content.parts.count, 1)
XCTAssertEqual(response.text, "Some information cited from an external source")
let citationMetadata = try XCTUnwrap(candidate.citationMetadata)
XCTAssertEqual(citationMetadata.citationSources.count, 1)
let citationSource = try XCTUnwrap(citationMetadata.citationSources.first)
XCTAssertEqual(citationSource.uri, "https://www.example.com/some-citation")
XCTAssertEqual(citationSource.startIndex, 574)
XCTAssertEqual(citationSource.endIndex, 705)
XCTAssertEqual(citationSource.license, "")
XCTAssertEqual(citationMetadata.citationSources.count, 4)
let citationSource1 = try XCTUnwrap(citationMetadata.citationSources[0])
XCTAssertEqual(citationSource1.uri, "https://www.example.com/some-citation-1")
XCTAssertEqual(citationSource1.startIndex, 0)
XCTAssertEqual(citationSource1.endIndex, 128)
XCTAssertEqual(citationSource1.license, "")
let citationSource2 = try XCTUnwrap(citationMetadata.citationSources[1])
XCTAssertEqual(citationSource2.uri, "https://www.example.com/some-citation-2")
XCTAssertEqual(citationSource2.startIndex, 130)
XCTAssertEqual(citationSource2.endIndex, 265)
XCTAssertEqual(citationSource2.license, "")
let citationSource3 = try XCTUnwrap(citationMetadata.citationSources[2])
XCTAssertEqual(citationSource3.uri, "https://www.example.com/some-citation-3")
XCTAssertEqual(citationSource3.startIndex, 272)
XCTAssertEqual(citationSource3.endIndex, 431)
XCTAssertEqual(citationSource3.license, "")
let citationSource4 = try XCTUnwrap(citationMetadata.citationSources[3])
XCTAssertEqual(citationSource4.uri, "https://www.example.com/some-citation-4")
XCTAssertEqual(citationSource4.startIndex, 444)
XCTAssertEqual(citationSource4.endIndex, 630)
XCTAssertEqual(citationSource4.license, "mit")
}

func testGenerateContent_success_quoteReply() async throws {
Expand Down Expand Up @@ -724,9 +739,21 @@ final class GenerativeModelTests: XCTestCase {

XCTAssertEqual(citations.count, 8)
XCTAssertTrue(citations
.contains(where: { $0.startIndex == 574 && $0.endIndex == 705 && !$0.uri.isEmpty }))
.contains(where: {
$0.startIndex == 0 && $0.endIndex == 128 && !$0.uri.isEmpty && $0.license.isEmpty
}))
XCTAssertTrue(citations
.contains(where: { $0.startIndex == 899 && $0.endIndex == 1026 && !$0.uri.isEmpty }))
.contains(where: {
$0.startIndex == 130 && $0.endIndex == 265 && !$0.uri.isEmpty && $0.license.isEmpty
}))
XCTAssertTrue(citations
.contains(where: {
$0.startIndex == 272 && $0.endIndex == 431 && !$0.uri.isEmpty && $0.license.isEmpty
}))
XCTAssertTrue(citations
.contains(where: {
$0.startIndex == 444 && $0.endIndex == 630 && !$0.uri.isEmpty && $0.license == "mit"
}))
}

func testGenerateContentStream_errorMidStream() async throws {
Expand Down

0 comments on commit 9ac3b2e

Please sign in to comment.