From a788378a40fbed4392e6aa301edfbec6ead25c6c Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 17 Jan 2024 18:20:13 -0500 Subject: [PATCH] Add `invalidAPIKey` case to `GenerateContentError` --- Sources/GoogleAI/Errors.swift | 41 ++++++++++++++++++- Sources/GoogleAI/GenerateContentError.swift | 3 ++ Sources/GoogleAI/GenerativeAIService.swift | 33 +++++++-------- Sources/GoogleAI/GenerativeModel.swift | 3 ++ .../GoogleAITests/GenerativeModelTests.swift | 10 ++--- 5 files changed, 66 insertions(+), 24 deletions(-) diff --git a/Sources/GoogleAI/Errors.swift b/Sources/GoogleAI/Errors.swift index 4fca53c..eaed192 100644 --- a/Sources/GoogleAI/Errors.swift +++ b/Sources/GoogleAI/Errors.swift @@ -18,11 +18,21 @@ struct RPCError: Error { let httpResponseCode: Int let message: String let status: RPCStatus + let details: [ErrorDetails] - init(httpResponseCode: Int, message: String, status: RPCStatus) { + private var errorInfo: ErrorDetails? { + return details.first { $0.isErrorInfo() } + } + + init(httpResponseCode: Int, message: String, status: RPCStatus, details: [ErrorDetails]) { self.httpResponseCode = httpResponseCode self.message = message self.status = status + self.details = details + } + + func isInvalidAPIKeyError() -> Bool { + return errorInfo?.reason == "API_KEY_INVALID" } } @@ -52,6 +62,8 @@ extension RPCError: Decodable { } else { self.status = .unknown } + + details = status.details } } @@ -59,6 +71,27 @@ struct ErrorStatus { let code: Int? let message: String? let status: RPCStatus? + let details: [ErrorDetails] +} + +struct ErrorDetails { + static let errorInfoType = "type.googleapis.com/google.rpc.ErrorInfo" + + let type: String + let reason: String? + let domain: String? + + func isErrorInfo() -> Bool { + return type == ErrorDetails.errorInfoType + } +} + +extension ErrorDetails: Decodable, Equatable { + enum CodingKeys: String, CodingKey { + case type = "@type" + case reason + case domain + } } extension ErrorStatus: Decodable { @@ -66,6 +99,7 @@ extension ErrorStatus: Decodable { case code case message case status + case details } init(from decoder: Decoder) throws { @@ -77,6 +111,11 @@ extension ErrorStatus: Decodable { } catch { status = .unknown } + if container.contains(.details) { + details = try container.decode([ErrorDetails].self, forKey: .details) + } else { + details = [] + } } } diff --git a/Sources/GoogleAI/GenerateContentError.swift b/Sources/GoogleAI/GenerateContentError.swift index 38e6b92..dd1cb84 100644 --- a/Sources/GoogleAI/GenerateContentError.swift +++ b/Sources/GoogleAI/GenerateContentError.swift @@ -24,4 +24,7 @@ public enum GenerateContentError: Error { /// A response didn't fully complete. See the `FinishReason` for more information. case responseStoppedEarly(reason: FinishReason, response: GenerateContentResponse) + + /// The provided API key is invalid. + case invalidAPIKey(underlying: Error) } diff --git a/Sources/GoogleAI/GenerativeAIService.swift b/Sources/GoogleAI/GenerativeAIService.swift index 16551b4..95d39d8 100644 --- a/Sources/GoogleAI/GenerativeAIService.swift +++ b/Sources/GoogleAI/GenerativeAIService.swift @@ -45,7 +45,7 @@ struct GenerativeAIService { Logging.network.error("[GoogleGenerativeAI] Response payload: \(responseString)") } - throw try JSONDecoder().decode(RPCError.self, from: data) + throw parseError(responseData: data) } return try parseResponse(T.Response.self, from: data) @@ -96,11 +96,7 @@ struct GenerativeAIService { } Logging.network.error("[GoogleGenerativeAI] Response payload: \(responseBody)") - do { - try parseError(responseBody: responseBody) - } catch { - continuation.finish(throwing: error) - } + continuation.finish(throwing: parseError(responseBody: responseBody)) return } @@ -138,12 +134,7 @@ struct GenerativeAIService { } if extraLines.count > 0 { - do { - try parseError(responseBody: extraLines) - } catch { - continuation.finish(throwing: error) - } - + continuation.finish(throwing: parseError(responseBody: extraLines)) return } @@ -198,15 +189,21 @@ struct GenerativeAIService { return data } - private func parseError(responseBody: String) throws { - let data = try jsonData(jsonText: responseBody) + private func parseError(responseBody: String) -> Error { + do { + let data = try jsonData(jsonText: responseBody) + return parseError(responseData: data) + } catch { + return error + } + } + private func parseError(responseData: Data) -> Error { do { - let rpcError = try JSONDecoder().decode(RPCError.self, from: data) - throw rpcError + return try JSONDecoder().decode(RPCError.self, from: responseData) } catch { - // TODO: Throw an error about an unrecognized error payload with the response body - throw error + // TODO: Return an error about an unrecognized error payload with the response body + return error } } diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift index 49b5796..6a35c06 100644 --- a/Sources/GoogleAI/GenerativeModel.swift +++ b/Sources/GoogleAI/GenerativeModel.swift @@ -108,6 +108,9 @@ public final class GenerativeModel { do { response = try await generativeAIService.loadRequest(request: generateContentRequest) } catch { + if let error = error as? RPCError, error.isInvalidAPIKeyError() { + throw GenerateContentError.invalidAPIKey(underlying: error) + } throw GenerateContentError.internalError(underlying: error) } diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index fd43bbb..4775456 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -181,12 +181,12 @@ final class GenerativeModelTests: XCTestCase { do { _ = try await model.generateContent(testPrompt) XCTFail("Should throw GenerateContentError.internalError; no error thrown.") - } catch let GenerateContentError.internalError(underlying: rpcError as RPCError) { - XCTAssertEqual(rpcError.status, .invalidArgument) - XCTAssertEqual(rpcError.httpResponseCode, expectedStatusCode) - XCTAssertTrue(rpcError.message.hasPrefix("API key not valid")) + } catch let GenerateContentError.invalidAPIKey(underlying: error as RPCError) { + XCTAssertEqual(error.status, .invalidArgument) + XCTAssertEqual(error.httpResponseCode, expectedStatusCode) + XCTAssertTrue(error.message.hasPrefix("API key not valid")) } catch { - XCTFail("Should throw GenerateContentError.internalError; error thrown: \(error)") + XCTFail("Should throw GenerateContentError.invalidAPIKey; error thrown: \(error)") } }