From bd5353cf4f5bda9abfef1f1592e0500c2ec24e08 Mon Sep 17 00:00:00 2001
From: Andrew Heard
Date: Wed, 13 Mar 2024 17:59:19 -0400
Subject: [PATCH] Wrap internal GenerativeModel

---
 Sources/GoogleAI/GenerativeModel.swift        | 137 +++++---
 Sources/Internal/Errors.swift                 |   2 +-
 Sources/Internal/GenerateContentError.swift   |  28 +++
 Sources/Internal/GenerativeModel.swift        | 175 ++++++++++++++++++
 .../GoogleAITests/GenerativeModelTests.swift  |  67 +++---
 5 files changed, 288 insertions(+), 121 deletions(-)
 create mode 100644 Sources/Internal/GenerateContentError.swift
 create mode 100644 Sources/Internal/GenerativeModel.swift

diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift
index d80131a..dedb02a 100644
--- a/Sources/GoogleAI/GenerativeModel.swift
+++ b/Sources/GoogleAI/GenerativeModel.swift
@@ -22,23 +22,10 @@ public final class GenerativeModel {
   // The prefix for a model resource in the Gemini API.
   private static let modelResourcePrefix = "models/"
 
-  // The prefix for a tuned model resource in the Gemini API.
-  private static let tunedModelResourcePrefix = "tunedModels/"
-
   /// The resource name of the model in the backend; has the format "models/model-name".
   let modelResourceName: String
 
-  /// The backing service responsible for sending and receiving model requests to the backend.
-  let generativeAIService: GenerativeAIService
-
-  /// Configuration parameters used for the MultiModalModel.
-  let generationConfig: GenerationConfig?
-
-  /// The safety settings to be used for prompts.
-  let safetySettings: [SafetySetting]?
-
-  /// Configuration parameters for sending requests to the backend.
-  let requestOptions: RequestOptions
+  let internalModel: InternalGenerativeAI.GenerativeModel
 
   /// Initializes a new remote model with the given parameters.
   ///
@@ -72,10 +59,6 @@ public final class GenerativeModel {
               requestOptions: RequestOptions = RequestOptions(),
               urlSession: URLSession) {
     modelResourceName = GenerativeModel.modelResourceName(name: name)
-    generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
-    self.generationConfig = generationConfig
-    self.safetySettings = safetySettings
-    self.requestOptions = requestOptions
 
     Logging.default.info("""
     [GoogleGenerativeAI] Model \(
@@ -85,6 +68,15 @@ public final class GenerativeModel {
     `\(Logging.enableArgumentKey, privacy: .public)` as a launch argument in Xcode.
     """)
     Logging.verbose.debug("[GoogleGenerativeAI] Verbose logging enabled.")
+
+    internalModel = InternalGenerativeAI.GenerativeModel(
+      modelResourceName: modelResourceName,
+      apiKey: apiKey,
+      generationConfig: generationConfig?.toInternal(),
+      safetySettings: safetySettings?.toInternal(),
+      requestOptions: requestOptions.toInternal(),
+      urlSession: urlSession
+    )
   }
 
   /// Generates content from String and/or image inputs, given to the model as a prompt, that are
@@ -114,36 +106,13 @@ public final class GenerativeModel {
   /// - Throws: A ``GenerateContentError`` if the request failed.
   public func generateContent(_ content: @autoclosure () throws -> [ModelContent]) async throws
     -> GenerateContentResponse {
-    let response: GenerateContentResponse
     do {
-      let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
-                                                              contents: content().toInternal(),
-                                                              generationConfig: generationConfig?
-                                                                .toInternal(),
-                                                              safetySettings: safetySettings?
-                                                                .toInternal(),
-                                                              isStreaming: false,
-                                                              options: requestOptions.toInternal())
-      response = try GenerateContentResponse(internalResponse: await generativeAIService
-        .loadRequest(request: generateContentRequest))
+      let evaluatedContent = try content()
+      return try await GenerateContentResponse(internalResponse: internalModel
+        .generateContent(evaluatedContent.toInternal()))
     } catch {
-      if let imageError = error as? ImageConversionError {
-        throw GenerateContentError.promptImageContentError(underlying: imageError)
-      }
       throw GenerativeModel.generateContentError(from: error)
     }
-
-    // Check the prompt feedback to see if the prompt was blocked.
-    if response.promptFeedback?.blockReason != nil {
-      throw GenerateContentError.promptBlocked(response: response)
-    }
-
-    // Check to see if an error should be thrown for stop reason.
-    if let reason = response.candidates.first?.finishReason, reason != .stop {
-      throw GenerateContentError.responseStoppedEarly(reason: reason, response: response)
-    }
-
-    return response
   }
 
   /// Generates content from String and/or image inputs, given to the model as a prompt, that are
@@ -180,54 +149,20 @@ public final class GenerativeModel {
       evaluatedContent = try content()
     } catch let underlying {
       return AsyncThrowingStream { continuation in
-        let error: Error
-        if let contentError = underlying as? ImageConversionError {
-          error = GenerateContentError.promptImageContentError(underlying: contentError)
-        } else {
-          error = GenerateContentError.internalError(underlying: underlying)
-        }
+        let error = GenerativeModel.generateContentError(from: underlying)
        continuation.finish(throwing: error)
      }
    }
 
-    let generateContentRequest = GenerateContentRequest(model: modelResourceName,
-                                                        contents: evaluatedContent.toInternal(),
-                                                        generationConfig: generationConfig?
-                                                          .toInternal(),
-                                                        safetySettings: safetySettings?
-                                                          .toInternal(),
-                                                        isStreaming: true,
-                                                        options: requestOptions.toInternal())
-
-    var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
-      .makeAsyncIterator()
+    var responseIterator = internalModel
+      .generateContentStream(evaluatedContent.toInternal()).makeAsyncIterator()
     return AsyncThrowingStream {
-      let response: GenerateContentResponse?
       do {
-        response = try await responseIterator.next()
+        return try await responseIterator.next()
+          .flatMap { GenerateContentResponse(internalResponse: $0) }
       } catch {
         throw GenerativeModel.generateContentError(from: error)
       }
-
-      // The responseIterator will return `nil` when it's done.
-      guard let response = response else {
-        // This is the end of the stream! Signal it by sending `nil`.
-        return nil
-      }
-
-      // Check the prompt feedback to see if the prompt was blocked.
-      if response.promptFeedback?.blockReason != nil {
-        throw GenerateContentError.promptBlocked(response: response)
-      }
-
-      // If the stream ended early unexpectedly, throw an error.
-      if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
-        throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
-      } else {
-        // Response was valid content, pass it along and continue.
-        return response
-      }
     }
   }
 
@@ -266,13 +201,13 @@ public final class GenerativeModel {
   public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
     -> CountTokensResponse {
     do {
-      let countTokensRequest = try CountTokensRequest(
-        model: modelResourceName,
-        contents: content().toInternal(),
-        options: requestOptions.toInternal()
-      )
-      let internalResponse = try await generativeAIService.loadRequest(request: countTokensRequest)
+      let internalResponse = try await internalModel.countTokens(content().toInternal())
       return CountTokensResponse(internalResponse: internalResponse)
+    } catch let error as InternalGenerativeAI.CountTokensError {
+      switch error {
+      case let .internalError(underlying: underlying):
+        throw CountTokensError.internalError(underlying: underlying)
+      }
     } catch {
       throw CountTokensError.internalError(underlying: error)
     }
@@ -293,6 +228,32 @@ public final class GenerativeModel {
   private static func generateContentError(from error: Error) -> GenerateContentError {
     if let error = error as? GenerateContentError {
       return error
+    } else if let error = error as? InternalGenerativeAI.GenerateContentError {
+      switch error {
+      case let .internalError(underlying: underlying):
+        return GenerateContentError.internalError(underlying: underlying)
+      case let .promptBlocked(response: response):
+        return GenerateContentError
+          .promptBlocked(response: GenerateContentResponse(internalResponse: response))
+      case let .responseStoppedEarly(reason: reason, response: response):
+        return GenerateContentError.responseStoppedEarly(
+          reason: FinishReason(internalReason: reason),
+          response: GenerateContentResponse(internalResponse: response)
+        )
+      }
+    } else if let error = error as? InternalGenerativeAI.InvalidCandidateError {
+      switch error {
+      case let .emptyContent(underlyingError):
+        return GenerateContentError
+          .internalError(underlying: InvalidCandidateError
+            .emptyContent(underlyingError: underlyingError))
+      case let .malformedContent(underlyingError):
+        return GenerateContentError
+          .internalError(underlying: InvalidCandidateError
+            .malformedContent(underlyingError: underlyingError))
+      }
+    } else if let error = error as? ImageConversionError {
+      return GenerateContentError.promptImageContentError(underlying: error)
     } else if let error = error as? RPCError, error.isInvalidAPIKeyError() {
       return GenerateContentError.invalidAPIKey
     } else if let error = error as? RPCError, error.isUnsupportedUserLocationError() {
diff --git a/Sources/Internal/Errors.swift b/Sources/Internal/Errors.swift
index 3e56366..479e58d 100644
--- a/Sources/Internal/Errors.swift
+++ b/Sources/Internal/Errors.swift
@@ -171,7 +171,7 @@ public enum RPCStatus: String, Decodable {
   case dataLoss = "DATA_LOSS"
 }
 
-enum InvalidCandidateError: Error {
+public enum InvalidCandidateError: Error {
   case emptyContent(underlyingError: Error)
   case malformedContent(underlyingError: Error)
 }
diff --git a/Sources/Internal/GenerateContentError.swift b/Sources/Internal/GenerateContentError.swift
new file mode 100644
index 0000000..6a17bbd
--- /dev/null
+++ b/Sources/Internal/GenerateContentError.swift
@@ -0,0 +1,28 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+
+/// Errors that occur when generating content from a model.
+@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
+public enum GenerateContentError: Error {
+  /// An internal error occurred. See the underlying error for more context.
+  case internalError(underlying: Error)
+
+  /// A prompt was blocked. See the response's `promptFeedback.blockReason` for more information.
+  case promptBlocked(response: GenerateContentResponse)
+
+  /// A response didn't fully complete. See the `FinishReason` for more information.
+  case responseStoppedEarly(reason: FinishReason, response: GenerateContentResponse)
+}
diff --git a/Sources/Internal/GenerativeModel.swift b/Sources/Internal/GenerativeModel.swift
new file mode 100644
index 0000000..351df21
--- /dev/null
+++ b/Sources/Internal/GenerativeModel.swift
@@ -0,0 +1,175 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+
+/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
+/// content based on various input types.
+@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
+public final class GenerativeModel {
+  /// The resource name of the model in the backend; has the format "models/model-name".
+  let modelResourceName: String
+
+  /// The backing service responsible for sending and receiving model requests to the backend.
+  let generativeAIService: GenerativeAIService
+
+  /// Configuration parameters used for the MultiModalModel.
+  let generationConfig: GenerationConfig?
+
+  /// The safety settings to be used for prompts.
+  let safetySettings: [SafetySetting]?
+
+  /// Configuration parameters for sending requests to the backend.
+  let requestOptions: RequestOptions
+
+  /// The designated initializer for this class.
+  public init(modelResourceName: String,
+              apiKey: String,
+              generationConfig: GenerationConfig? = nil,
+              safetySettings: [SafetySetting]? = nil,
+              requestOptions: RequestOptions,
+              urlSession: URLSession) {
+    self.modelResourceName = modelResourceName
+    generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
+    self.generationConfig = generationConfig
+    self.safetySettings = safetySettings
+    self.requestOptions = requestOptions
+
+    Logging.default.info("""
+    [GoogleGenerativeAI] Model \(
+      modelResourceName,
+      privacy: .public
+    ) initialized. To enable additional logging, add \
+    `\(Logging.enableArgumentKey, privacy: .public)` as a launch argument in Xcode.
+    """)
+    Logging.verbose.debug("[GoogleGenerativeAI] Verbose logging enabled.")
+  }
+
+  /// Generates new content from input content given to the model as a prompt.
+  ///
+  /// - Parameter content: The input(s) given to the model as a prompt.
+  /// - Returns: The generated content response from the model.
+  /// - Throws: A ``GenerateContentError`` if the request failed.
+  public func generateContent(_ content: [ModelContent]) async throws
+    -> GenerateContentResponse {
+    let generateContentRequest = GenerateContentRequest(model: modelResourceName,
+                                                        contents: content,
+                                                        generationConfig: generationConfig,
+                                                        safetySettings: safetySettings,
+                                                        isStreaming: false,
+                                                        options: requestOptions)
+    let response = try await generativeAIService.loadRequest(request: generateContentRequest)
+
+    // Check the prompt feedback to see if the prompt was blocked.
+    if response.promptFeedback?.blockReason != nil {
+      throw InternalGenerativeAI.GenerateContentError.promptBlocked(response: response)
+    }
+
+    // Check to see if an error should be thrown for stop reason.
+    if let reason = response.candidates.first?.finishReason, reason != .stop {
+      throw InternalGenerativeAI.GenerateContentError.responseStoppedEarly(
+        reason: reason,
+        response: response
+      )
+    }
+
+    return response
+  }
+
+  /// Generates new content from input content given to the model as a prompt.
+  ///
+  /// - Parameter content: The input(s) given to the model as a prompt.
+  /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
+  ///   error if an error occurred.
+  @available(macOS 12.0, *)
+  public func generateContentStream(_ content: [ModelContent])
+    -> AsyncThrowingStream<GenerateContentResponse, Error> {
+    let generateContentRequest = GenerateContentRequest(model: modelResourceName,
+                                                        contents: content,
+                                                        generationConfig: generationConfig,
+                                                        safetySettings: safetySettings,
+                                                        isStreaming: true,
+                                                        options: requestOptions)
+
+    var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
+      .makeAsyncIterator()
+    return AsyncThrowingStream {
+      let response = try await responseIterator.next()
+
+      // The responseIterator will return `nil` when it's done.
+      guard let response = response else {
+        // This is the end of the stream! Signal it by sending `nil`.
+        return nil
+      }
+
+      // Check the prompt feedback to see if the prompt was blocked.
+      if response.promptFeedback?.blockReason != nil {
+        throw InternalGenerativeAI.GenerateContentError.promptBlocked(response: response)
+      }
+
+      // If the stream ended early unexpectedly, throw an error.
+      if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
+        throw InternalGenerativeAI.GenerateContentError.responseStoppedEarly(
+          reason: finishReason,
+          response: response
+        )
+      } else {
+        // Response was valid content, pass it along and continue.
+        return response
+      }
+    }
+  }
+
+  /// Creates a new chat conversation using this model with the provided history.
+  // TODO(andrewheard): Implement me
+  // public func startChat(history: [ModelContent] = []) -> Chat {
+  //   return Chat(model: self, history: history)
+  // }
+
+  /// Runs the model's tokenizer on the input content and returns the token count.
+  ///
+  /// - Parameter content: The input given to the model as a prompt.
+  /// - Returns: The results of running the model's tokenizer on the input; contains
+  ///   ``CountTokensResponse/totalTokens``.
+  /// - Throws: A ``CountTokensError`` if the tokenization request failed or the input content was
+  ///   invalid.
+  public func countTokens(_ content: [ModelContent]) async throws
+    -> CountTokensResponse {
+    do {
+      let countTokensRequest = CountTokensRequest(
+        model: modelResourceName,
+        contents: content,
+        options: requestOptions
+      )
+      return try await generativeAIService.loadRequest(request: countTokensRequest)
+    } catch {
+      throw CountTokensError.internalError(underlying: error)
+    }
+  }
+
+//  /// Returns a model resource name of the form "models/model-name" based on `name`.
+//  private static func modelResourceName(name: String) -> String {
+//    if name.contains("/") {
+//      return name
+//    } else {
+//      return modelResourcePrefix + name
+//    }
+//  }
+}
+
+/// See ``GenerativeModel/countTokens(_:)-9spwl``.
+@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
+public enum CountTokensError: Error {
+  case internalError(underlying: Error)
+}
diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift
index 6b202bc..2aecbc5 100644
--- a/Tests/GoogleAITests/GenerativeModelTests.swift
+++ b/Tests/GoogleAITests/GenerativeModelTests.swift
@@ -28,7 +28,7 @@ final class GenerativeModelTests: XCTestCase {
   ]
 
   var urlSession: URLSession!
-  var model: GenerativeModel!
+  var model: GoogleGenerativeAI.GenerativeModel!
 
   override func setUp() async throws {
     let configuration = URLSessionConfiguration.default
@@ -183,7 +183,7 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
-    } catch GenerateContentError.invalidAPIKey {
+    } catch GoogleGenerativeAI.GenerateContentError.invalidAPIKey {
       // Do nothing, catching a GenerateContentError.invalidAPIKey error is expected.
     } catch {
       XCTFail("Should throw GenerateContentError.invalidAPIKey; error thrown: \(error)")
@@ -200,10 +200,9 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
-    } catch let GenerateContentError
-      .internalError(underlying: invalidCandidateError as InternalGenerativeAI
+    } catch let GoogleGenerativeAI.GenerateContentError
+      .internalError(underlying: invalidCandidateError as GoogleGenerativeAI
        .InvalidCandidateError) {
-      // TODO(andrewheard): Convert to GoogleGenerativeAI.InvalidCandidateError
      guard case let .emptyContent(decodingError) = invalidCandidateError else {
        XCTFail("Not an InvalidCandidateError.emptyContent error: \(invalidCandidateError)")
        return
@@ -225,7 +224,7 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw")
-    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
+    } catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, response) {
       XCTAssertEqual(reason, .safety)
       XCTAssertEqual(response.text, "No")
     } catch {
@@ -243,7 +242,7 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw")
-    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
+    } catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, response) {
       XCTAssertEqual(reason, .safety)
       XCTAssertNil(response.text)
     } catch {
@@ -263,7 +262,8 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
-    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
+    } catch let GoogleGenerativeAI.GenerateContentError
+      .internalError(underlying: rpcError as RPCError) {
       XCTAssertEqual(rpcError.status, .invalidArgument)
       XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
       XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
@@ -282,7 +282,7 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw")
-    } catch let GenerateContentError.promptBlocked(response) {
+    } catch let GoogleGenerativeAI.GenerateContentError.promptBlocked(response) {
       XCTAssertNil(response.text)
     } catch {
       XCTFail("Should throw a promptBlocked")
@@ -299,7 +299,7 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw")
-    } catch let GenerateContentError.responseStoppedEarly(reason, response) {
+    } catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, response) {
       XCTAssertEqual(reason, .unknown)
       XCTAssertEqual(response.text, "Some text")
     } catch {
@@ -317,7 +317,7 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw")
-    } catch let GenerateContentError.promptBlocked(response) {
+    } catch let GoogleGenerativeAI.GenerateContentError.promptBlocked(response) {
       let promptFeedback = try XCTUnwrap(response.promptFeedback)
       XCTAssertEqual(promptFeedback.blockReason, .unknown)
     } catch {
@@ -337,7 +337,8 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
-    } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
+    } catch let GoogleGenerativeAI.GenerateContentError
+      .internalError(underlying: rpcError as RPCError) {
       XCTAssertEqual(rpcError.status, .notFound)
       XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
       XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
@@ -357,7 +358,7 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.generateContent(testPrompt)
       XCTFail("Should throw GenerateContentError.unsupportedUserLocation; no error thrown.")
-    } catch GenerateContentError.unsupportedUserLocation {
+    } catch GoogleGenerativeAI.GenerateContentError.unsupportedUserLocation {
       return
     }
 
@@ -377,7 +378,8 @@ final class GenerativeModelTests: XCTestCase {
     XCTAssertNil(content)
     XCTAssertNotNil(responseError)
 
-    let generateContentError = try XCTUnwrap(responseError as? GenerateContentError)
+    let generateContentError = try XCTUnwrap(responseError as? GoogleGenerativeAI
+      .GenerateContentError)
     guard case let .internalError(underlyingError) = generateContentError else {
       XCTFail("Not an internal error: \(generateContentError)")
       return
@@ -401,7 +403,8 @@ final class GenerativeModelTests: XCTestCase {
     XCTAssertNil(content)
     XCTAssertNotNil(responseError)
 
-    let generateContentError = try XCTUnwrap(responseError as? GenerateContentError)
+    let generateContentError = try XCTUnwrap(responseError as? GoogleGenerativeAI
+      .GenerateContentError)
     guard case let .internalError(underlyingError) = generateContentError else {
       XCTFail("Not an internal error: \(generateContentError)")
       return
@@ -431,13 +434,13 @@ final class GenerativeModelTests: XCTestCase {
     XCTAssertNil(content)
     XCTAssertNotNil(responseError)
-    let generateContentError = try XCTUnwrap(responseError as? GenerateContentError)
+    let generateContentError = try XCTUnwrap(responseError as? GoogleGenerativeAI
+      .GenerateContentError)
     guard case let .internalError(underlyingError) = generateContentError else {
       XCTFail("Not an internal error: \(generateContentError)")
       return
     }
-    // TODO(andrewheard): Convert to GoogleGenerativeAI.InvalidCandidateError
-    let invalidCandidateError = try XCTUnwrap(underlyingError as? InternalGenerativeAI
+    let invalidCandidateError = try XCTUnwrap(underlyingError as? GoogleGenerativeAI
       .InvalidCandidateError)
     guard case let .malformedContent(malformedContentUnderlyingError) = invalidCandidateError else {
       XCTFail("Not a malformed content error: \(invalidCandidateError)")
       return
@@ -496,7 +499,7 @@ final class GenerativeModelTests: XCTestCase {
       for try await _ in stream {
         XCTFail("No content is there, this shouldn't happen.")
       }
-    } catch GenerateContentError.invalidAPIKey {
+    } catch GoogleGenerativeAI.GenerateContentError.invalidAPIKey {
       // invalidAPIKey error is as expected, nothing else to check.
       return
     }
@@ -516,8 +519,8 @@ final class GenerativeModelTests: XCTestCase {
       for try await _ in stream {
         XCTFail("No content is there, this shouldn't happen.")
       }
-    } catch GenerateContentError.internalError(_ as InternalGenerativeAI.InvalidCandidateError) {
-      // TODO(andrewheard): Convert to GoogleGenerativeAI.InvalidCandidateError
+    } catch GoogleGenerativeAI.GenerateContentError
+      .internalError(_ as GoogleGenerativeAI.InvalidCandidateError) {
       // Underlying error is as expected, nothing else to check.
       return
     }
@@ -537,7 +540,7 @@ final class GenerativeModelTests: XCTestCase {
       for try await _ in stream {
         XCTFail("Content shouldn't be shown, this shouldn't happen.")
       }
-    } catch let GenerateContentError.responseStoppedEarly(reason, _) {
+    } catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, _) {
       XCTAssertEqual(reason, .safety)
       return
     }
@@ -557,7 +560,7 @@ final class GenerativeModelTests: XCTestCase {
       for try await _ in stream {
         XCTFail("Content shouldn't be shown, this shouldn't happen.")
       }
-    } catch let GenerateContentError.promptBlocked(response) {
+    } catch let GoogleGenerativeAI.GenerateContentError.promptBlocked(response) {
       XCTAssertEqual(response.promptFeedback?.blockReason, .safety)
       return
     }
@@ -577,7 +580,7 @@ final class GenerativeModelTests: XCTestCase {
       for try await content in stream {
         XCTAssertNotNil(content.text)
       }
-    } catch let GenerateContentError.responseStoppedEarly(reason, _) {
+    } catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, _) {
       XCTAssertEqual(reason, .unknown)
       return
     }
@@ -677,7 +680,7 @@ final class GenerativeModelTests: XCTestCase {
         XCTAssertNotNil(content.text)
         responseCount += 1
       }
-    } catch let GenerateContentError.internalError(rpcError as RPCError) {
+    } catch let GoogleGenerativeAI.GenerateContentError.internalError(rpcError as RPCError) {
       XCTAssertEqual(rpcError.httpResponseCode, 499)
       XCTAssertEqual(rpcError.status, .cancelled)
 
@@ -697,7 +700,7 @@ final class GenerativeModelTests: XCTestCase {
       for try await content in stream {
         XCTFail("Unexpected content in stream: \(content)")
       }
-    } catch let GenerateContentError.internalError(underlying) {
+    } catch let GoogleGenerativeAI.GenerateContentError.internalError(underlying) {
       XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
       return
     }
@@ -717,7 +720,7 @@ final class GenerativeModelTests: XCTestCase {
       for try await content in stream {
         XCTFail("Unexpected content in stream: \(content)")
       }
-    } catch let GenerateContentError.internalError(underlying as DecodingError) {
+    } catch let GoogleGenerativeAI.GenerateContentError.internalError(underlying as DecodingError) {
       guard case let .dataCorrupted(context) = underlying else {
         XCTFail("Not a data corrupted error: \(underlying)")
         return
@@ -741,9 +744,8 @@ final class GenerativeModelTests: XCTestCase {
       for try await content in stream {
         XCTFail("Unexpected content in stream: \(content)")
       }
-    } catch let GenerateContentError
-      .internalError(underlyingError as InternalGenerativeAI.InvalidCandidateError) {
-      // TODO(andrewheard): Convert to GoogleGenerativeAI.InvalidCandidateError
+    } catch let GoogleGenerativeAI.GenerateContentError
+      .internalError(underlyingError as GoogleGenerativeAI.InvalidCandidateError) {
       guard case let .malformedContent(contentError) = underlyingError else {
         XCTFail("Not a malformed content error: \(underlyingError)")
         return
@@ -769,7 +771,7 @@ final class GenerativeModelTests: XCTestCase {
       for try await content in stream {
         XCTFail("Unexpected content in stream: \(content)")
       }
-    } catch GenerateContentError.unsupportedUserLocation {
+    } catch GoogleGenerativeAI.GenerateContentError.unsupportedUserLocation {
       return
     }
 
@@ -823,7 +825,8 @@ final class GenerativeModelTests: XCTestCase {
     do {
       _ = try await model.countTokens("Why is the sky blue?")
       XCTFail("Request should not have succeeded.")
-    } catch let CountTokensError.internalError(rpcError as RPCError) {
+    } catch let GoogleGenerativeAI.CountTokensError
+      .internalError(underlying: rpcError as RPCError) {
       XCTAssertEqual(rpcError.httpResponseCode, 404)
       XCTAssertEqual(rpcError.status, .notFound)
       XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
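
Usage sketch (illustrative, not part of the patch): after this change the public
GoogleGenerativeAI.GenerativeModel keeps its existing API surface while delegating request
transport and response validation (the prompt-feedback and finish-reason checks) to
InternalGenerativeAI.GenerativeModel, mapping internal errors back to the public
GenerateContentError cases in generateContentError(from:). A minimal caller therefore looks the
same as before; the apiKey value and the "gemini-1.0-pro" model name below are placeholders:

    import GoogleGenerativeAI

    func generate(apiKey: String) async {
      // Public wrapper; transport and validation now happen in InternalGenerativeAI.
      let model = GenerativeModel(name: "gemini-1.0-pro", apiKey: apiKey)
      do {
        let response = try await model.generateContent("Why is the sky blue?")
        print(response.text ?? "(no text)")
      } catch let error as GenerateContentError {
        // Internal promptBlocked/responseStoppedEarly errors surface as public cases.
        print("Generation failed: \(error)")
      } catch {
        print("Unexpected error: \(error)")
      }
    }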