Skip to content

Commit

Permalink
Wrap internal GenerativeModel
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Mar 13, 2024
1 parent 9454225 commit bd5353c
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 121 deletions.
137 changes: 49 additions & 88 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,10 @@ public final class GenerativeModel {
// The prefix for a model resource in the Gemini API.
private static let modelResourcePrefix = "models/"

// The prefix for a tuned model resource in the Gemini API.
private static let tunedModelResourcePrefix = "tunedModels/"

/// The resource name of the model in the backend; has the format "models/model-name".
let modelResourceName: String

/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

/// Configuration parameters used for the MultiModalModel.
let generationConfig: GenerationConfig?

/// The safety settings to be used for prompts.
let safetySettings: [SafetySetting]?

/// Configuration parameters for sending requests to the backend.
let requestOptions: RequestOptions
let internalModel: InternalGenerativeAI.GenerativeModel

/// Initializes a new remote model with the given parameters.
///
Expand Down Expand Up @@ -72,10 +59,6 @@ public final class GenerativeModel {
requestOptions: RequestOptions = RequestOptions(),
urlSession: URLSession) {
modelResourceName = GenerativeModel.modelResourceName(name: name)
generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
self.generationConfig = generationConfig
self.safetySettings = safetySettings
self.requestOptions = requestOptions

Logging.default.info("""
[GoogleGenerativeAI] Model \(
Expand All @@ -85,6 +68,15 @@ public final class GenerativeModel {
`\(Logging.enableArgumentKey, privacy: .public)` as a launch argument in Xcode.
""")
Logging.verbose.debug("[GoogleGenerativeAI] Verbose logging enabled.")

internalModel = InternalGenerativeAI.GenerativeModel(
modelResourceName: modelResourceName,
apiKey: apiKey,
generationConfig: generationConfig?.toInternal(),
safetySettings: safetySettings?.toInternal(),
requestOptions: requestOptions.toInternal(),
urlSession: urlSession
)
}

/// Generates content from String and/or image inputs, given to the model as a prompt, that are
Expand Down Expand Up @@ -114,36 +106,13 @@ public final class GenerativeModel {
/// - Throws: A ``GenerateContentError`` if the request failed.
public func generateContent(_ content: @autoclosure () throws -> [ModelContent]) async throws
-> GenerateContentResponse {
let response: GenerateContentResponse
do {
let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
contents: content().toInternal(),
generationConfig: generationConfig?
.toInternal(),
safetySettings: safetySettings?
.toInternal(),
isStreaming: false,
options: requestOptions.toInternal())
response = try GenerateContentResponse(internalResponse: await generativeAIService
.loadRequest(request: generateContentRequest))
let evaluatedContent = try content()
return try await GenerateContentResponse(internalResponse: internalModel
.generateContent(evaluatedContent.toInternal()))
} catch {
if let imageError = error as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: imageError)
}
throw GenerativeModel.generateContentError(from: error)
}

// Check the prompt feedback to see if the prompt was blocked.
if response.promptFeedback?.blockReason != nil {
throw GenerateContentError.promptBlocked(response: response)
}

// Check to see if an error should be thrown for stop reason.
if let reason = response.candidates.first?.finishReason, reason != .stop {
throw GenerateContentError.responseStoppedEarly(reason: reason, response: response)
}

return response
}

/// Generates content from String and/or image inputs, given to the model as a prompt, that are
Expand Down Expand Up @@ -180,54 +149,20 @@ public final class GenerativeModel {
evaluatedContent = try content()
} catch let underlying {
return AsyncThrowingStream { continuation in
let error: Error
if let contentError = underlying as? ImageConversionError {
error = GenerateContentError.promptImageContentError(underlying: contentError)
} else {
error = GenerateContentError.internalError(underlying: underlying)
}
let error = GenerativeModel.generateContentError(from: underlying)
continuation.finish(throwing: error)
}
}

let generateContentRequest = GenerateContentRequest(model: modelResourceName,
contents: evaluatedContent.toInternal(),
generationConfig: generationConfig?
.toInternal(),
safetySettings: safetySettings?
.toInternal(),
isStreaming: true,
options: requestOptions.toInternal())

var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
.makeAsyncIterator()
var responseIterator = internalModel
.generateContentStream(evaluatedContent.toInternal()).makeAsyncIterator()
return AsyncThrowingStream {
let response: GenerateContentResponse?
do {
response = try await responseIterator.next()
return try await responseIterator.next()
.flatMap { GenerateContentResponse(internalResponse: $0) }
} catch {
throw GenerativeModel.generateContentError(from: error)
}

// The responseIterator will return `nil` when it's done.
guard let response = response else {
// This is the end of the stream! Signal it by sending `nil`.
return nil
}

// Check the prompt feedback to see if the prompt was blocked.
if response.promptFeedback?.blockReason != nil {
throw GenerateContentError.promptBlocked(response: response)
}

// If the stream ended early unexpectedly, throw an error.
if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
} else {
// Response was valid content, pass it along and continue.
return response
}
}
}

Expand Down Expand Up @@ -266,13 +201,13 @@ public final class GenerativeModel {
public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
-> CountTokensResponse {
do {
let countTokensRequest = try CountTokensRequest(
model: modelResourceName,
contents: content().toInternal(),
options: requestOptions.toInternal()
)
let internalResponse = try await generativeAIService.loadRequest(request: countTokensRequest)
let internalResponse = try await internalModel.countTokens(content().toInternal())
return CountTokensResponse(internalResponse: internalResponse)
} catch let error as InternalGenerativeAI.CountTokensError {
switch error {
case let .internalError(underlying: underlying):
throw CountTokensError.internalError(underlying: underlying)
}
} catch {
throw CountTokensError.internalError(underlying: error)
}
Expand All @@ -293,6 +228,32 @@ public final class GenerativeModel {
private static func generateContentError(from error: Error) -> GenerateContentError {
if let error = error as? GenerateContentError {
return error
} else if let error = error as? InternalGenerativeAI.GenerateContentError {
switch error {
case let .internalError(underlying: underlying):
return GenerateContentError.internalError(underlying: underlying)
case let .promptBlocked(response: response):
return GenerateContentError
.promptBlocked(response: GenerateContentResponse(internalResponse: response))
case let .responseStoppedEarly(reason: reason, response: response):
return GenerateContentError.responseStoppedEarly(
reason: FinishReason(internalReason: reason),
response: GenerateContentResponse(internalResponse: response)
)
}
} else if let error = error as? InternalGenerativeAI.InvalidCandidateError {
switch error {
case let .emptyContent(underlyingError):
return GenerateContentError
.internalError(underlying: InvalidCandidateError
.emptyContent(underlyingError: underlyingError))
case let .malformedContent(underlyingError):
return GenerateContentError
.internalError(underlying: InvalidCandidateError
.malformedContent(underlyingError: underlyingError))
}
} else if let error = error as? ImageConversionError {
return GenerateContentError.promptImageContentError(underlying: error)
} else if let error = error as? RPCError, error.isInvalidAPIKeyError() {
return GenerateContentError.invalidAPIKey
} else if let error = error as? RPCError, error.isUnsupportedUserLocationError() {
Expand Down
2 changes: 1 addition & 1 deletion Sources/Internal/Errors.swift
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public enum RPCStatus: String, Decodable {
case dataLoss = "DATA_LOSS"
}

enum InvalidCandidateError: Error {
/// Errors describing an invalid candidate.
///
/// Both cases wrap the error that was originally raised as `underlyingError`.
// NOTE(review): the throw sites are not visible here; presumably these are raised while decoding
// a candidate's content. Confirm against the decoding code.
public enum InvalidCandidateError: Error {
/// The candidate's content was empty; `underlyingError` holds the originating error.
case emptyContent(underlyingError: Error)
/// The candidate's content was malformed; `underlyingError` holds the originating error.
case malformedContent(underlyingError: Error)
}
28 changes: 28 additions & 0 deletions Sources/Internal/GenerateContentError.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// Errors that occur when generating content from a model.
///
/// Every case carries an associated value that gives callers more context about the failure:
/// either the underlying error or the response that triggered the error.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public enum GenerateContentError: Error {
/// An internal error occurred. See the underlying error for more context.
///
/// - Parameter underlying: The error that caused the failure.
case internalError(underlying: Error)

/// A prompt was blocked. See the response's `promptFeedback.blockReason` for more information.
///
/// - Parameter response: The response reporting that the prompt was blocked.
case promptBlocked(response: GenerateContentResponse)

/// A response didn't fully complete. See the `FinishReason` for more information.
///
/// - Parameters:
///   - reason: The finish reason reported for the incomplete response.
///   - response: The response that stopped early.
case responseStoppedEarly(reason: FinishReason, response: GenerateContentResponse)
}
Loading

0 comments on commit bd5353c

Please sign in to comment.