From eeee2a3e87e85f95b06813ae65923af271270fbe Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 13 Mar 2024 18:42:13 -0400 Subject: [PATCH] import struct InternalGenerativeAI.RPCError --- .../GoogleAITests/GenerativeModelTests.swift | 64 +++++++++---------- 1 file changed, 30 insertions(+), 34 deletions(-) diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 2aecbc5..25a6af6 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -@testable import InternalGenerativeAI import XCTest @testable import GoogleGenerativeAI +import struct InternalGenerativeAI.RPCError + @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *) final class GenerativeModelTests: XCTestCase { let testPrompt = "What sorts of questions can I ask you?" - let safetyRatingsNegligible: [GoogleGenerativeAI.SafetyRating] = [ + let safetyRatingsNegligible: [SafetyRating] = [ .init(category: .sexuallyExplicit, probability: .negligible), .init(category: .hateSpeech, probability: .negligible), .init(category: .harassment, probability: .negligible), @@ -28,7 +29,7 @@ final class GenerativeModelTests: XCTestCase { ] var urlSession: URLSession! - var model: GoogleGenerativeAI.GenerativeModel! + var model: GenerativeModel! override func setUp() async throws { let configuration = URLSessionConfiguration.default @@ -137,7 +138,7 @@ final class GenerativeModelTests: XCTestCase { } func testGenerateContent_success_unknownEnum_safetyRatings() async throws { - let expectedSafetyRatings: [GoogleGenerativeAI.SafetyRating] = [ + let expectedSafetyRatings: [SafetyRating] = [ .init(category: .harassment, probability: .medium), .init(category: .dangerousContent, probability: .unknown), .init(category: .unknown, probability: .high), @@ -183,7 +184,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw GenerateContentError.internalError; no error thrown.") - } catch GoogleGenerativeAI.GenerateContentError.invalidAPIKey { + } catch GenerateContentError.invalidAPIKey { // Do nothing, catching a GenerateContentError.invalidAPIKey error is expected. } catch { XCTFail("Should throw GenerateContentError.invalidAPIKey; error thrown: \(error)") @@ -200,7 +201,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw GenerateContentError.internalError; no error thrown.") - } catch let GoogleGenerativeAI.GenerateContentError + } catch let GenerateContentError .internalError(underlying: invalidCandidateError as GoogleGenerativeAI .InvalidCandidateError) { guard case let .emptyContent(decodingError) = invalidCandidateError else { @@ -224,7 +225,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw") - } catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, response) { + } catch let GenerateContentError.responseStoppedEarly(reason, response) { XCTAssertEqual(reason, .safety) XCTAssertEqual(response.text, "No") } catch { @@ -242,7 +243,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw") - } catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, response) { + } catch let GenerateContentError.responseStoppedEarly(reason, response) { XCTAssertEqual(reason, .safety) XCTAssertNil(response.text) } catch { @@ -262,8 +263,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw GenerateContentError.internalError; no error thrown.") - } catch let GoogleGenerativeAI.GenerateContentError - .internalError(underlying: rpcError as RPCError) { + } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) { XCTAssertEqual(rpcError.status, .invalidArgument) XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode) XCTAssertEqual(rpcError.message, "Request contains an invalid argument.") @@ -282,7 +282,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw") - } catch let GoogleGenerativeAI.GenerateContentError.promptBlocked(response) { + } catch let GenerateContentError.promptBlocked(response) { XCTAssertNil(response.text) } catch { XCTFail("Should throw a promptBlocked") @@ -299,7 +299,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw") - } catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, response) { + } catch let GenerateContentError.responseStoppedEarly(reason, response) { XCTAssertEqual(reason, .unknown) XCTAssertEqual(response.text, "Some text") } catch { @@ -317,7 +317,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw") - } catch let GoogleGenerativeAI.GenerateContentError.promptBlocked(response) { + } catch let GenerateContentError.promptBlocked(response) { let promptFeedback = try XCTUnwrap(response.promptFeedback) XCTAssertEqual(promptFeedback.blockReason, .unknown) } catch { @@ -337,8 +337,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw GenerateContentError.internalError; no error thrown.") - } catch let GoogleGenerativeAI.GenerateContentError - .internalError(underlying: rpcError as RPCError) { + } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) { XCTAssertEqual(rpcError.status, .notFound) XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode) XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found")) @@ -358,7 +357,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw GenerateContentError.unsupportedUserLocation; no error thrown.") - } catch GoogleGenerativeAI.GenerateContentError.unsupportedUserLocation { + } catch GenerateContentError.unsupportedUserLocation { return } @@ -369,7 +368,7 @@ final class GenerativeModelTests: XCTestCase { MockURLProtocol.requestHandler = try nonHTTPRequestHandler() var responseError: Error? - var content: GoogleGenerativeAI.GenerateContentResponse? + var content: GenerateContentResponse? do { content = try await model.generateContent(testPrompt) } catch { @@ -394,7 +393,7 @@ final class GenerativeModelTests: XCTestCase { ) var responseError: Error? - var content: GoogleGenerativeAI.GenerateContentResponse? + var content: GenerateContentResponse? do { content = try await model.generateContent(testPrompt) } catch { @@ -425,7 +424,7 @@ final class GenerativeModelTests: XCTestCase { ) var responseError: Error? - var content: GoogleGenerativeAI.GenerateContentResponse? + var content: GenerateContentResponse? do { content = try await model.generateContent(testPrompt) } catch { @@ -499,7 +498,7 @@ final class GenerativeModelTests: XCTestCase { for try await _ in stream { XCTFail("No content is there, this shouldn't happen.") } - } catch GoogleGenerativeAI.GenerateContentError.invalidAPIKey { + } catch GenerateContentError.invalidAPIKey { // invalidAPIKey error is as expected, nothing else to check. return } @@ -519,8 +518,7 @@ final class GenerativeModelTests: XCTestCase { for try await _ in stream { XCTFail("No content is there, this shouldn't happen.") } - } catch GoogleGenerativeAI.GenerateContentError - .internalError(_ as GoogleGenerativeAI.InvalidCandidateError) { + } catch GenerateContentError.internalError(_ as InvalidCandidateError) { // Underlying error is as expected, nothing else to check. return } @@ -540,7 +538,7 @@ final class GenerativeModelTests: XCTestCase { for try await _ in stream { XCTFail("Content shouldn't be shown, this shouldn't happen.") } - } catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, _) { + } catch let GenerateContentError.responseStoppedEarly(reason, _) { XCTAssertEqual(reason, .safety) return } @@ -560,7 +558,7 @@ final class GenerativeModelTests: XCTestCase { for try await _ in stream { XCTFail("Content shouldn't be shown, this shouldn't happen.") } - } catch let GoogleGenerativeAI.GenerateContentError.promptBlocked(response) { + } catch let GenerateContentError.promptBlocked(response) { XCTAssertEqual(response.promptFeedback?.blockReason, .safety) return } @@ -580,7 +578,7 @@ final class GenerativeModelTests: XCTestCase { for try await content in stream { XCTAssertNotNil(content.text) } - } catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, _) { + } catch let GenerateContentError.responseStoppedEarly(reason, _) { XCTAssertEqual(reason, .unknown) return } @@ -650,7 +648,7 @@ final class GenerativeModelTests: XCTestCase { ) let stream = model.generateContentStream("Hi") - var citations: [GoogleGenerativeAI.Citation] = [] + var citations: [Citation] = [] for try await content in stream { XCTAssertNotNil(content.text) let candidate = try XCTUnwrap(content.candidates.first) @@ -680,7 +678,7 @@ final class GenerativeModelTests: XCTestCase { XCTAssertNotNil(content.text) responseCount += 1 } - } catch let GoogleGenerativeAI.GenerateContentError.internalError(rpcError as RPCError) { + } catch let GenerateContentError.internalError(rpcError as RPCError) { XCTAssertEqual(rpcError.httpResponseCode, 499) XCTAssertEqual(rpcError.status, .cancelled) @@ -700,7 +698,7 @@ final class GenerativeModelTests: XCTestCase { for try await content in stream { XCTFail("Unexpected content in stream: \(content)") } - } catch let GoogleGenerativeAI.GenerateContentError.internalError(underlying) { + } catch let GenerateContentError.internalError(underlying) { XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.") return } @@ -720,7 +718,7 @@ final class GenerativeModelTests: XCTestCase { for try await content in stream { XCTFail("Unexpected content in stream: \(content)") } - } catch let GoogleGenerativeAI.GenerateContentError.internalError(underlying as DecodingError) { + } catch let GenerateContentError.internalError(underlying as DecodingError) { guard case let .dataCorrupted(context) = underlying else { XCTFail("Not a data corrupted error: \(underlying)") return @@ -744,8 +742,7 @@ final class GenerativeModelTests: XCTestCase { for try await content in stream { XCTFail("Unexpected content in stream: \(content)") } - } catch let GoogleGenerativeAI.GenerateContentError - .internalError(underlyingError as GoogleGenerativeAI.InvalidCandidateError) { + } catch let GenerateContentError.internalError(underlyingError as InvalidCandidateError) { guard case let .malformedContent(contentError) = underlyingError else { XCTFail("Not a malformed content error: \(underlyingError)") return @@ -771,7 +768,7 @@ final class GenerativeModelTests: XCTestCase { for try await content in stream { XCTFail("Unexpected content in stream: \(content)") } - } catch GoogleGenerativeAI.GenerateContentError.unsupportedUserLocation { + } catch GenerateContentError.unsupportedUserLocation { return } @@ -825,8 +822,7 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.countTokens("Why is the sky blue?") XCTFail("Request should not have succeeded.") - } catch let GoogleGenerativeAI.CountTokensError - .internalError(underlying: rpcError as RPCError) { + } catch let CountTokensError.internalError(underlying: rpcError as RPCError) { XCTAssertEqual(rpcError.httpResponseCode, 404) XCTAssertEqual(rpcError.status, .notFound) XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))