Skip to content

Commit

Permalink
import struct InternalGenerativeAI.RPCError
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Mar 13, 2024
1 parent bd5353c commit eeee2a3
Showing 1 changed file with 30 additions and 34 deletions.
64 changes: 30 additions & 34 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.

@testable import InternalGenerativeAI
import XCTest

@testable import GoogleGenerativeAI

import struct InternalGenerativeAI.RPCError

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
final class GenerativeModelTests: XCTestCase {
let testPrompt = "What sorts of questions can I ask you?"
let safetyRatingsNegligible: [GoogleGenerativeAI.SafetyRating] = [
let safetyRatingsNegligible: [SafetyRating] = [
.init(category: .sexuallyExplicit, probability: .negligible),
.init(category: .hateSpeech, probability: .negligible),
.init(category: .harassment, probability: .negligible),
.init(category: .dangerousContent, probability: .negligible),
]

var urlSession: URLSession!
var model: GoogleGenerativeAI.GenerativeModel!
var model: GenerativeModel!

override func setUp() async throws {
let configuration = URLSessionConfiguration.default
Expand Down Expand Up @@ -137,7 +138,7 @@ final class GenerativeModelTests: XCTestCase {
}

func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
let expectedSafetyRatings: [GoogleGenerativeAI.SafetyRating] = [
let expectedSafetyRatings: [SafetyRating] = [
.init(category: .harassment, probability: .medium),
.init(category: .dangerousContent, probability: .unknown),
.init(category: .unknown, probability: .high),
Expand Down Expand Up @@ -183,7 +184,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
} catch GoogleGenerativeAI.GenerateContentError.invalidAPIKey {
} catch GenerateContentError.invalidAPIKey {
// Do nothing, catching a GenerateContentError.invalidAPIKey error is expected.
} catch {
XCTFail("Should throw GenerateContentError.invalidAPIKey; error thrown: \(error)")
Expand All @@ -200,7 +201,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
} catch let GoogleGenerativeAI.GenerateContentError
} catch let GenerateContentError
.internalError(underlying: invalidCandidateError as GoogleGenerativeAI
.InvalidCandidateError) {
guard case let .emptyContent(decodingError) = invalidCandidateError else {
Expand All @@ -224,7 +225,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, response) {
} catch let GenerateContentError.responseStoppedEarly(reason, response) {
XCTAssertEqual(reason, .safety)
XCTAssertEqual(response.text, "No")
} catch {
Expand All @@ -242,7 +243,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, response) {
} catch let GenerateContentError.responseStoppedEarly(reason, response) {
XCTAssertEqual(reason, .safety)
XCTAssertNil(response.text)
} catch {
Expand All @@ -262,8 +263,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
} catch let GoogleGenerativeAI.GenerateContentError
.internalError(underlying: rpcError as RPCError) {
} catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
XCTAssertEqual(rpcError.status, .invalidArgument)
XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
Expand All @@ -282,7 +282,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GoogleGenerativeAI.GenerateContentError.promptBlocked(response) {
} catch let GenerateContentError.promptBlocked(response) {
XCTAssertNil(response.text)
} catch {
XCTFail("Should throw a promptBlocked")
Expand All @@ -299,7 +299,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, response) {
} catch let GenerateContentError.responseStoppedEarly(reason, response) {
XCTAssertEqual(reason, .unknown)
XCTAssertEqual(response.text, "Some text")
} catch {
Expand All @@ -317,7 +317,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GoogleGenerativeAI.GenerateContentError.promptBlocked(response) {
} catch let GenerateContentError.promptBlocked(response) {
let promptFeedback = try XCTUnwrap(response.promptFeedback)
XCTAssertEqual(promptFeedback.blockReason, .unknown)
} catch {
Expand All @@ -337,8 +337,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
} catch let GoogleGenerativeAI.GenerateContentError
.internalError(underlying: rpcError as RPCError) {
} catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
XCTAssertEqual(rpcError.status, .notFound)
XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
Expand All @@ -358,7 +357,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.unsupportedUserLocation; no error thrown.")
} catch GoogleGenerativeAI.GenerateContentError.unsupportedUserLocation {
} catch GenerateContentError.unsupportedUserLocation {
return
}

Expand All @@ -369,7 +368,7 @@ final class GenerativeModelTests: XCTestCase {
MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

var responseError: Error?
var content: GoogleGenerativeAI.GenerateContentResponse?
var content: GenerateContentResponse?
do {
content = try await model.generateContent(testPrompt)
} catch {
Expand All @@ -394,7 +393,7 @@ final class GenerativeModelTests: XCTestCase {
)

var responseError: Error?
var content: GoogleGenerativeAI.GenerateContentResponse?
var content: GenerateContentResponse?
do {
content = try await model.generateContent(testPrompt)
} catch {
Expand Down Expand Up @@ -425,7 +424,7 @@ final class GenerativeModelTests: XCTestCase {
)

var responseError: Error?
var content: GoogleGenerativeAI.GenerateContentResponse?
var content: GenerateContentResponse?
do {
content = try await model.generateContent(testPrompt)
} catch {
Expand Down Expand Up @@ -499,7 +498,7 @@ final class GenerativeModelTests: XCTestCase {
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
} catch GoogleGenerativeAI.GenerateContentError.invalidAPIKey {
} catch GenerateContentError.invalidAPIKey {
// invalidAPIKey error is as expected, nothing else to check.
return
}
Expand All @@ -519,8 +518,7 @@ final class GenerativeModelTests: XCTestCase {
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
} catch GoogleGenerativeAI.GenerateContentError
.internalError(_ as GoogleGenerativeAI.InvalidCandidateError) {
} catch GenerateContentError.internalError(_ as InvalidCandidateError) {
// Underlying error is as expected, nothing else to check.
return
}
Expand All @@ -540,7 +538,7 @@ final class GenerativeModelTests: XCTestCase {
for try await _ in stream {
XCTFail("Content shouldn't be shown, this shouldn't happen.")
}
} catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, _) {
} catch let GenerateContentError.responseStoppedEarly(reason, _) {
XCTAssertEqual(reason, .safety)
return
}
Expand All @@ -560,7 +558,7 @@ final class GenerativeModelTests: XCTestCase {
for try await _ in stream {
XCTFail("Content shouldn't be shown, this shouldn't happen.")
}
} catch let GoogleGenerativeAI.GenerateContentError.promptBlocked(response) {
} catch let GenerateContentError.promptBlocked(response) {
XCTAssertEqual(response.promptFeedback?.blockReason, .safety)
return
}
Expand All @@ -580,7 +578,7 @@ final class GenerativeModelTests: XCTestCase {
for try await content in stream {
XCTAssertNotNil(content.text)
}
} catch let GoogleGenerativeAI.GenerateContentError.responseStoppedEarly(reason, _) {
} catch let GenerateContentError.responseStoppedEarly(reason, _) {
XCTAssertEqual(reason, .unknown)
return
}
Expand Down Expand Up @@ -650,7 +648,7 @@ final class GenerativeModelTests: XCTestCase {
)

let stream = model.generateContentStream("Hi")
var citations: [GoogleGenerativeAI.Citation] = []
var citations: [Citation] = []
for try await content in stream {
XCTAssertNotNil(content.text)
let candidate = try XCTUnwrap(content.candidates.first)
Expand Down Expand Up @@ -680,7 +678,7 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertNotNil(content.text)
responseCount += 1
}
} catch let GoogleGenerativeAI.GenerateContentError.internalError(rpcError as RPCError) {
} catch let GenerateContentError.internalError(rpcError as RPCError) {
XCTAssertEqual(rpcError.httpResponseCode, 499)
XCTAssertEqual(rpcError.status, .cancelled)

Expand All @@ -700,7 +698,7 @@ final class GenerativeModelTests: XCTestCase {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
}
} catch let GoogleGenerativeAI.GenerateContentError.internalError(underlying) {
} catch let GenerateContentError.internalError(underlying) {
XCTAssertEqual(underlying.localizedDescription, "Response was not an HTTP response.")
return
}
Expand All @@ -720,7 +718,7 @@ final class GenerativeModelTests: XCTestCase {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
}
} catch let GoogleGenerativeAI.GenerateContentError.internalError(underlying as DecodingError) {
} catch let GenerateContentError.internalError(underlying as DecodingError) {
guard case let .dataCorrupted(context) = underlying else {
XCTFail("Not a data corrupted error: \(underlying)")
return
Expand All @@ -744,8 +742,7 @@ final class GenerativeModelTests: XCTestCase {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
}
} catch let GoogleGenerativeAI.GenerateContentError
.internalError(underlyingError as GoogleGenerativeAI.InvalidCandidateError) {
} catch let GenerateContentError.internalError(underlyingError as InvalidCandidateError) {
guard case let .malformedContent(contentError) = underlyingError else {
XCTFail("Not a malformed content error: \(underlyingError)")
return
Expand All @@ -771,7 +768,7 @@ final class GenerativeModelTests: XCTestCase {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
}
} catch GoogleGenerativeAI.GenerateContentError.unsupportedUserLocation {
} catch GenerateContentError.unsupportedUserLocation {
return
}

Expand Down Expand Up @@ -825,8 +822,7 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.countTokens("Why is the sky blue?")
XCTFail("Request should not have succeeded.")
} catch let GoogleGenerativeAI.CountTokensError
.internalError(underlying: rpcError as RPCError) {
} catch let CountTokensError.internalError(underlying: rpcError as RPCError) {
XCTAssertEqual(rpcError.httpResponseCode, 404)
XCTAssertEqual(rpcError.status, .notFound)
XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
Expand Down

0 comments on commit eeee2a3

Please sign in to comment.