Skip to content

Commit

Permalink
Add RequestOptions to GenerativeModel constructor (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Jan 31, 2024
1 parent dcbdb5e commit f029e13
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 6 deletions.
1 change: 1 addition & 0 deletions Sources/GoogleAI/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Foundation
/// Request sent to the backend to obtain a token count for a piece of content.
struct CountTokensRequest {
/// Name of the model the request is addressed to.
let model: String
/// Content whose tokens are to be counted.
let contents: [ModelContent]
/// Optional configuration for sending the request (e.g. a timeout); `nil` when none was supplied.
let options: RequestOptions?
}

extension CountTokensRequest: Encodable {
Expand Down
1 change: 1 addition & 0 deletions Sources/GoogleAI/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct GenerateContentRequest {
let generationConfig: GenerationConfig?
let safetySettings: [SafetySetting]?
let isStreaming: Bool
let options: RequestOptions?
}

extension GenerateContentRequest: Encodable {
Expand Down
17 changes: 17 additions & 0 deletions Sources/GoogleAI/GenerativeAIRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,21 @@ protocol GenerativeAIRequest: Encodable {
associatedtype Response: Decodable

var url: URL { get }

var options: RequestOptions? { get }
}

/// Configuration parameters for sending requests to the backend.
public struct RequestOptions {
/// The request’s timeout interval in seconds; if not specified uses the default value for a
/// `URLRequest`.
let timeout: TimeInterval?

/// Initializes a request options object.
///
/// - Parameter timeout: The request’s timeout interval in seconds; if not specified uses the
///   default value for a `URLRequest`.
public init(timeout: TimeInterval? = nil) {
self.timeout = timeout
}
}
4 changes: 4 additions & 0 deletions Sources/GoogleAI/GenerativeAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ struct GenerativeAIService {
encoder.keyEncodingStrategy = .convertToSnakeCase
urlRequest.httpBody = try encoder.encode(request)

if let timeoutInterval = request.options?.timeout {
urlRequest.timeoutInterval = timeoutInterval
}

return urlRequest
}

Expand Down
25 changes: 20 additions & 5 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ public final class GenerativeModel {
/// The safety settings to be used for prompts.
let safetySettings: [SafetySetting]?

/// Configuration parameters for sending requests to the backend.
let requestOptions: RequestOptions?

/// Initializes a new remote model with the given parameters.
///
/// - Parameter name: The name of the model to be used, e.g., "gemini-pro" or "models/gemini-pro".
Expand All @@ -40,15 +43,18 @@ public final class GenerativeModel {
/// should use.
/// - Parameter safetySettings: A value describing what types of harmful content your model
/// should allow.
/// - Parameter requestOptions: Configuration parameters for sending requests to the backend.
public convenience init(name: String,
apiKey: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil) {
safetySettings: [SafetySetting]? = nil,
requestOptions: RequestOptions? = nil) {
self.init(
name: name,
apiKey: apiKey,
generationConfig: generationConfig,
safetySettings: safetySettings,
requestOptions: requestOptions,
urlSession: .shared
)
}
Expand All @@ -58,11 +64,13 @@ public final class GenerativeModel {
apiKey: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
requestOptions: RequestOptions? = nil,
urlSession: URLSession) {
modelResourceName = GenerativeModel.modelResourceName(name: name)
generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
self.generationConfig = generationConfig
self.safetySettings = safetySettings
self.requestOptions = requestOptions

Logging.default.info("""
[GoogleGenerativeAI] Model \(
Expand Down Expand Up @@ -98,12 +106,14 @@ public final class GenerativeModel {
/// - Parameter content: The input(s) given to the model as a prompt.
/// - Returns: The generated content response from the model.
/// - Throws: A ``GenerateContentError`` if the request failed.
public func generateContent(_ content: [ModelContent]) async throws -> GenerateContentResponse {
public func generateContent(_ content: [ModelContent]) async throws
-> GenerateContentResponse {
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
contents: content,
generationConfig: generationConfig,
safetySettings: safetySettings,
isStreaming: false)
isStreaming: false,
options: requestOptions)
let response: GenerateContentResponse
do {
response = try await generativeAIService.loadRequest(request: generateContentRequest)
Expand Down Expand Up @@ -156,7 +166,8 @@ public final class GenerativeModel {
contents: content,
generationConfig: generationConfig,
safetySettings: safetySettings,
isStreaming: true)
isStreaming: true,
options: requestOptions)

var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
.makeAsyncIterator()
Expand Down Expand Up @@ -220,7 +231,11 @@ public final class GenerativeModel {
/// - Throws: A ``CountTokensError`` if the tokenization request failed.
public func countTokens(_ content: [ModelContent]) async throws
-> CountTokensResponse {
let countTokensRequest = CountTokensRequest(model: modelResourceName, contents: content)
let countTokensRequest = CountTokensRequest(
model: modelResourceName,
contents: content,
options: requestOptions
)

do {
return try await generativeAIService.loadRequest(request: countTokensRequest)
Expand Down
81 changes: 80 additions & 1 deletion Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,27 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(content.text, "This is the generated content.")
}

/// Verifies that a custom timeout supplied through `RequestOptions` is carried by the
/// `URLRequest` of a unary `generateContent` call (the mock handler asserts the interval).
func testGenerateContent_requestOptions_customTimeout() async throws {
  let customTimeout: TimeInterval = 150.0
  // The handler checks `request.timeoutInterval` against the value passed here.
  MockURLProtocol.requestHandler = try httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json",
    timeout: customTimeout
  )
  model = GenerativeModel(
    name: "my-model",
    apiKey: "API_KEY",
    requestOptions: RequestOptions(timeout: customTimeout),
    urlSession: urlSession
  )

  let result = try await model.generateContent(testPrompt)

  XCTAssertEqual(result.candidates.count, 1)
}

// MARK: - Generate Content (Streaming)

func testGenerateContentStream_failureInvalidAPIKey() async throws {
Expand Down Expand Up @@ -746,6 +767,32 @@ final class GenerativeModelTests: XCTestCase {
XCTFail("Expected an unsupported user location error.")
}

/// Verifies that a custom timeout supplied through `RequestOptions` is carried by the
/// `URLRequest` of a streaming `generateContentStream` call (the mock handler asserts the
/// interval) and that exactly one chunk is delivered.
func testGenerateContentStream_requestOptions_customTimeout() async throws {
  let customTimeout: TimeInterval = 150.0
  // The handler checks `request.timeoutInterval` against the value passed here.
  MockURLProtocol.requestHandler = try httpRequestHandler(
    forResource: "streaming-success-basic-reply-short",
    withExtension: "txt",
    timeout: customTimeout
  )
  model = GenerativeModel(
    name: "my-model",
    apiKey: "API_KEY",
    requestOptions: RequestOptions(timeout: customTimeout),
    urlSession: urlSession
  )

  var chunkCount = 0
  for try await chunk in model.generateContentStream(testPrompt) {
    XCTAssertNotNil(chunk.text)
    chunkCount += 1
  }

  XCTAssertEqual(chunkCount, 1)
}

// MARK: - Count Tokens

func testCountTokens_succeeds() async throws {
Expand Down Expand Up @@ -777,6 +824,27 @@ final class GenerativeModelTests: XCTestCase {
XCTFail("Expected internal RPCError.")
}

/// Verifies that a custom timeout supplied through `RequestOptions` is carried by the
/// `URLRequest` of a `countTokens` call (the mock handler asserts the interval).
func testCountTokens_requestOptions_customTimeout() async throws {
  let customTimeout: TimeInterval = 150.0
  // The handler checks `request.timeoutInterval` against the value passed here.
  MockURLProtocol.requestHandler = try httpRequestHandler(
    forResource: "success-total-tokens",
    withExtension: "json",
    timeout: customTimeout
  )
  model = GenerativeModel(
    name: "my-model",
    apiKey: "API_KEY",
    requestOptions: RequestOptions(timeout: customTimeout),
    urlSession: urlSession
  )

  let tokenCount = try await model.countTokens(testPrompt)

  XCTAssertEqual(tokenCount.totalTokens, 6)
}

// MARK: - Helpers

private func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
Expand All @@ -797,14 +865,17 @@ final class GenerativeModelTests: XCTestCase {

private func httpRequestHandler(forResource name: String,
withExtension ext: String,
statusCode: Int = 200) throws -> ((URLRequest) throws -> (
statusCode: Int = 200,
timeout: TimeInterval = URLRequest
.defaultTimeoutInterval()) throws -> ((URLRequest) throws -> (
URLResponse,
AsyncLineSequence<URL.AsyncBytes>?
)) {
let fileURL = try XCTUnwrap(Bundle.module.url(forResource: name, withExtension: ext))
return { request in
let requestURL = try XCTUnwrap(request.url)
XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1)
XCTAssertEqual(request.timeoutInterval, timeout)
let response = try XCTUnwrap(HTTPURLResponse(
url: requestURL,
statusCode: statusCode,
Expand All @@ -822,3 +893,11 @@ private extension String {
return components(separatedBy: substring).count - 1
}
}

private extension URLRequest {
  /// Reports the `timeoutInterval` carried by a newly created `URLRequest`, i.e. the value a
  /// request has when no custom timeout was applied.
  static func defaultTimeoutInterval() -> TimeInterval {
    // A throwaway request is built only to read its timeout; the URL is a constant,
    // well-formed literal, so the force-unwrap cannot fail.
    URLRequest(url: URL(string: "https://example.com")!).timeoutInterval
  }
}

0 comments on commit f029e13

Please sign in to comment.