Skip to content

Commit

Permalink
Support model names prefixed with "models/" (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Dec 15, 2023
1 parent 9d68694 commit fead303
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
26 changes: 19 additions & 7 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ import Foundation
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
/// content based on various input types.
public final class GenerativeModel {
/// Name of the model in the backend.
private let modelName: String
// The prefix for a model resource in the Gemini API.
private static let modelResourcePrefix = "models/"

/// The resource name of the model in the backend; has the format "models/model-name".
private let modelResourceName: String

/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService
Expand All @@ -31,7 +34,7 @@ public final class GenerativeModel {

/// Initializes a new remote model with the given parameters.
///
/// - Parameter name: The name of the model to be used.
/// - Parameter name: The name of the model to be used, e.g., "gemini-pro" or "models/gemini-pro".
/// - Parameter apiKey: The API key for your project.
/// - Parameter generationConfig: A value containing the content generation parameters your model
/// should use.
Expand All @@ -56,7 +59,7 @@ public final class GenerativeModel {
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
urlSession: URLSession) {
modelName = name
modelResourceName = GenerativeModel.modelResourceName(name: name)
generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
self.generationConfig = generationConfig
self.safetySettings = safetySettings
Expand Down Expand Up @@ -96,7 +99,7 @@ public final class GenerativeModel {
/// - Returns: The generated content response from the model.
/// - Throws: A ``GenerateContentError`` if the request failed.
public func generateContent(_ content: [ModelContent]) async throws -> GenerateContentResponse {
let generateContentRequest = GenerateContentRequest(model: "models/\(modelName)",
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
contents: content,
generationConfig: generationConfig,
safetySettings: safetySettings,
Expand Down Expand Up @@ -147,7 +150,7 @@ public final class GenerativeModel {
/// error if an error occurred.
public func generateContentStream(_ content: [ModelContent])
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let generateContentRequest = GenerateContentRequest(model: "models/\(modelName)",
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
contents: content,
generationConfig: generationConfig,
safetySettings: safetySettings,
Expand Down Expand Up @@ -215,14 +218,23 @@ public final class GenerativeModel {
/// - Throws: A ``CountTokensError`` if the tokenization request failed.
public func countTokens(_ content: [ModelContent]) async throws
-> CountTokensResponse {
let countTokensRequest = CountTokensRequest(model: "models/\(modelName)", contents: content)
let countTokensRequest = CountTokensRequest(model: modelResourceName, contents: content)

do {
return try await generativeAIService.loadRequest(request: countTokensRequest)
} catch {
throw CountTokensError.internalError(underlying: error)
}
}

/// Returns a model resource name of the form "models/model-name" based on `name`.
private static func modelResourceName(name: String) -> String {
if name.hasPrefix(modelResourcePrefix) {
return name
} else {
return modelResourcePrefix + name
}
}
}

/// See ``GenerativeModel/countTokens(_:)-9spwl``.
Expand Down
27 changes: 26 additions & 1 deletion Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,22 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedSafetyRatings)
}

func testGenerateContent_success_prefixedModelName() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-basic-reply-short",
withExtension: "json"
)
let model = GenerativeModel(
// Model name is prefixed with "models/".
name: "models/test-model",
apiKey: "API_KEY",
urlSession: urlSession
)

_ = try await model.generateContent(testPrompt)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol
Expand Down Expand Up @@ -731,8 +747,10 @@ final class GenerativeModelTests: XCTestCase {
)) {
let fileURL = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: ext))
return { request in
let requestURL = try XCTUnwrap(request.url)
XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
let response = try XCTUnwrap(HTTPURLResponse(
url: request.url!,
url: requestURL,
statusCode: statusCode,
httpVersion: nil,
headerFields: nil
Expand All @@ -741,3 +759,10 @@ final class GenerativeModelTests: XCTestCase {
}
}
}

private extension String {
/// Returns the number of occurrences of `substring` in the `String`.
func occurrenceCount(of substring: String) -> Int {
return components(separatedBy: substring).count - 1
}
}

0 comments on commit fead303

Please sign in to comment.