Skip to content

Commit

Permalink
Make ErrorDetails an enum with associated values BadRequest/ErrorInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Feb 9, 2024
1 parent 03b5dcf commit 8e613ec
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 101 deletions.
250 changes: 153 additions & 97 deletions Sources/GoogleAI/Errors.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,143 @@
import Foundation

struct RPCError: Error {
let httpResponseCode: Int
let message: String
let status: RPCStatus
let details: [ErrorDetails]
enum ErrorDetails {
case badRequest(BadRequest)
case errorInfo(ErrorInfo)
case unknown(String)

struct BadRequest {
static let type = "type.googleapis.com/google.rpc.BadRequest"

struct FieldViolation: Decodable {
let field: String?
let description: String?
}

let type: String
let fieldViolations: [FieldViolation]
}

struct ErrorInfo {
static let type = "type.googleapis.com/google.rpc.ErrorInfo"

let type: String
let reason: String?
let domain: String?
}
}

enum Status: String, Decodable {
// Not an error; returned on success.
case ok = "OK"

// The operation was cancelled, typically by the caller.
case cancelled = "CANCELLED"

// Unknown error.
case unknown = "UNKNOWN"

// The client specified an invalid argument.
case invalidArgument = "INVALID_ARGUMENT"

// The deadline expired before the operation could complete.
case deadlineExceeded = "DEADLINE_EXCEEDED"

// Some requested entity (e.g., file or directory) was not found.
case notFound = "NOT_FOUND"

// The entity that a client attempted to create (e.g., file or directory) already exists.
case alreadyExists = "ALREADY_EXISTS"

// The caller does not have permission to execute the specified operation.
case permissionDenied = "PERMISSION_DENIED"

// The request does not have valid authentication credentials for the operation.
case unauthenticated = "UNAUTHENTICATED"

// Some resource has been exhausted, perhaps a per-user quota, or perhaps the entire file system
// is out of space.
case resourceExhausted = "RESOURCE_EXHAUSTED"

// The operation was rejected because the system is not in a state required for the operation's
// execution.
case failedPrecondition = "FAILED_PRECONDITION"

// The operation was aborted, typically due to a concurrency issue such as a sequencer check
// failure or transaction abort.
case aborted = "ABORTED"

// The operation was attempted past the valid range.
case outOfRange = "OUT_OF_RANGE"

private var errorInfo: ErrorDetails? {
return details.first { $0.isErrorInfo() }
// The operation is not implemented or is not supported/enabled in this service.
case unimplemented = "UNIMPLEMENTED"

// Internal errors.
case internalError = "INTERNAL"

// The service is currently unavailable.
case unavailable = "UNAVAILABLE"

// Unrecoverable data loss or corruption.
case dataLoss = "DATA_LOSS"
}

init(httpResponseCode: Int, message: String, status: RPCStatus, details: [ErrorDetails]) {
self.httpResponseCode = httpResponseCode
let code: Int
let message: String
let status: Status
let details: [ErrorDetails]

init(httpResponseCode: Int, message: String, status: Status, details: [ErrorDetails]) {
code = httpResponseCode
self.message = message
self.status = status
self.details = details
}

func isInvalidAPIKeyError() -> Bool {
return errorInfo?.reason == "API_KEY_INVALID"
return details.contains { errorDetails in
switch errorDetails {
case let .errorInfo(errorInfo):
return errorInfo.reason == "API_KEY_INVALID"
default:
return false
}
}
}

func isUnsupportedUserLocationError() -> Bool {
return message == RPCErrorMessage.unsupportedUserLocation.rawValue
}
}

enum InvalidCandidateError: Error {
case emptyContent(underlyingError: Error)
case malformedContent(underlyingError: Error)
}

// MARK: - Decodable Conformance

extension RPCError: Decodable {
enum CodingKeys: CodingKey {
case error
}

struct ErrorStatus {
let code: Int?
let message: String?
let status: RPCError.Status?
let details: [RPCError.ErrorDetails]
}

init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let status = try container.decode(ErrorStatus.self, forKey: .error)

if let code = status.code {
httpResponseCode = code
self.code = code
} else {
httpResponseCode = -1
code = -1
}

if let message = status.message {
Expand All @@ -71,34 +170,7 @@ extension RPCError: Decodable {
}
}

struct ErrorStatus {
let code: Int?
let message: String?
let status: RPCStatus?
let details: [ErrorDetails]
}

struct ErrorDetails {
static let errorInfoType = "type.googleapis.com/google.rpc.ErrorInfo"

let type: String
let reason: String?
let domain: String?

func isErrorInfo() -> Bool {
return type == ErrorDetails.errorInfoType
}
}

extension ErrorDetails: Decodable, Equatable {
enum CodingKeys: String, CodingKey {
case type = "@type"
case reason
case domain
}
}

extension ErrorStatus: Decodable {
extension RPCError.ErrorStatus: Decodable {
enum CodingKeys: CodingKey {
case code
case message
Expand All @@ -111,79 +183,63 @@ extension ErrorStatus: Decodable {
code = try container.decodeIfPresent(Int.self, forKey: .code)
message = try container.decodeIfPresent(String.self, forKey: .message)
do {
status = try container.decodeIfPresent(RPCStatus.self, forKey: .status)
status = try container.decodeIfPresent(RPCError.Status.self, forKey: .status)
} catch {
status = .unknown
}
if container.contains(.details) {
details = try container.decode([ErrorDetails].self, forKey: .details)
details = try container.decode([RPCError.ErrorDetails].self, forKey: .details)
} else {
details = []
}
}
}

enum RPCStatus: String, Decodable {
// Not an error; returned on success.
case ok = "OK"

// The operation was cancelled, typically by the caller.
case cancelled = "CANCELLED"

// Unknown error.
case unknown = "UNKNOWN"

// The client specified an invalid argument.
case invalidArgument = "INVALID_ARGUMENT"

// The deadline expired before the operation could complete.
case deadlineExceeded = "DEADLINE_EXCEEDED"

// Some requested entity (e.g., file or directory) was not found.
case notFound = "NOT_FOUND"

// The entity that a client attempted to create (e.g., file or directory) already exists.
case alreadyExists = "ALREADY_EXISTS"

// The caller does not have permission to execute the specified operation.
case permissionDenied = "PERMISSION_DENIED"

// The request does not have valid authentication credentials for the operation.
case unauthenticated = "UNAUTHENTICATED"

// Some resource has been exhausted, perhaps a per-user quota, or perhaps the entire file system
// is out of space.
case resourceExhausted = "RESOURCE_EXHAUSTED"

// The operation was rejected because the system is not in a state required for the operation's
// execution.
case failedPrecondition = "FAILED_PRECONDITION"

// The operation was aborted, typically due to a concurrency issue such as a sequencer check
// failure or transaction abort.
case aborted = "ABORTED"

// The operation was attempted past the valid range.
case outOfRange = "OUT_OF_RANGE"

// The operation is not implemented or is not supported/enabled in this service.
case unimplemented = "UNIMPLEMENTED"
extension RPCError.ErrorDetails: Decodable {
enum CodingKeys: String, CodingKey {
case type = "@type"
}

// Internal errors.
case internalError = "INTERNAL"
init(from decoder: Decoder) throws {
let errorDetailsContainer = try decoder.container(keyedBy: CodingKeys.self)
let type = try errorDetailsContainer.decode(String.self, forKey: .type)
if type == BadRequest.type {
let badRequestContainer = try decoder.singleValueContainer()
let badRequest = try badRequestContainer.decode(BadRequest.self)
self = RPCError.ErrorDetails.badRequest(badRequest)
} else if type == ErrorInfo.type {
let errorInfoContainer = try decoder.singleValueContainer()
let errorInfo = try errorInfoContainer.decode(ErrorInfo.self)
self = RPCError.ErrorDetails.errorInfo(errorInfo)
} else {
self = RPCError.ErrorDetails.unknown(type)
}
}
}

// The service is currently unavailable.
case unavailable = "UNAVAILABLE"
extension RPCError.ErrorDetails.BadRequest: Decodable {
enum CodingKeys: String, CodingKey {
case type = "@type"
case fieldViolations
}

// Unrecoverable data loss or corruption.
case dataLoss = "DATA_LOSS"
init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
type = try container.decode(String.self, forKey: .type)
fieldViolations = try container.decode([FieldViolation].self, forKey: .fieldViolations)
}
}

enum RPCErrorMessage: String {
case unsupportedUserLocation = "User location is not supported for the API use."
extension RPCError.ErrorDetails.ErrorInfo: Decodable {
enum CodingKeys: String, CodingKey {
case type = "@type"
case reason
case domain
}
}

enum InvalidCandidateError: Error {
case emptyContent(underlyingError: Error)
case malformedContent(underlyingError: Error)
// MARK: - Private

private enum RPCErrorMessage: String {
case unsupportedUserLocation = "User location is not supported for the API use."
}
8 changes: 4 additions & 4 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ final class GenerativeModelTests: XCTestCase {
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
} catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
XCTAssertEqual(rpcError.status, .invalidArgument)
XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
XCTAssertEqual(rpcError.code, expectedStatusCode)
XCTAssertEqual(rpcError.message, "Request contains an invalid argument.")
} catch {
XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
Expand Down Expand Up @@ -335,7 +335,7 @@ final class GenerativeModelTests: XCTestCase {
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
} catch let GenerateContentError.internalError(underlying: rpcError as RPCError) {
XCTAssertEqual(rpcError.status, .notFound)
XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode)
XCTAssertEqual(rpcError.code, expectedStatusCode)
XCTAssertTrue(rpcError.message.hasPrefix("models/unknown is not found"))
} catch {
XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)")
Expand Down Expand Up @@ -671,7 +671,7 @@ final class GenerativeModelTests: XCTestCase {
responseCount += 1
}
} catch let GenerateContentError.internalError(rpcError as RPCError) {
XCTAssertEqual(rpcError.httpResponseCode, 499)
XCTAssertEqual(rpcError.code, 499)
XCTAssertEqual(rpcError.status, .cancelled)

// Check the content count is correct.
Expand Down Expand Up @@ -815,7 +815,7 @@ final class GenerativeModelTests: XCTestCase {
_ = try await model.countTokens("Why is the sky blue?")
XCTFail("Request should not have succeeded.")
} catch let CountTokensError.internalError(rpcError as RPCError) {
XCTAssertEqual(rpcError.httpResponseCode, 404)
XCTAssertEqual(rpcError.code, 404)
XCTAssertEqual(rpcError.status, .notFound)
XCTAssert(rpcError.message.hasPrefix("models/test-model-name is not found"))
return
Expand Down

0 comments on commit 8e613ec

Please sign in to comment.