Skip to content

Commit

Permalink
[Vertex AI] Update Imagen public APIs to match API proposal (#14388)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Jan 28, 2025
1 parent d0e2014 commit ed10c4f
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenFileDataImage {
public struct ImagenGCSImage {
public let mimeType: String
public let gcsURI: String

Expand All @@ -26,7 +26,7 @@ public struct ImagenFileDataImage {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenFileDataImage: ImagenImageRepresentable {
extension ImagenGCSImage: ImagenImageRepresentable {
// TODO(andrewheard): Make this public when the SDK supports Imagen operations that take images as
// input (upscaling / editing).
var _internalImagenImage: _InternalImagenImage {
Expand All @@ -35,12 +35,12 @@ extension ImagenFileDataImage: ImagenImageRepresentable {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenFileDataImage: Equatable {}
extension ImagenGCSImage: Equatable {}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenFileDataImage: Decodable {
extension ImagenGCSImage: Decodable {
enum CodingKeys: String, CodingKey {
case mimeType
case gcsURI = "gcsUri"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenGenerationConfig {
public var numberOfImages: Int?
public var negativePrompt: String?
public var imageFormat: ImagenImageFormat?
public var numberOfImages: Int?
public var aspectRatio: ImagenAspectRatio?
public var imageFormat: ImagenImageFormat?
public var addWatermark: Bool?

public init(numberOfImages: Int? = nil, negativePrompt: String? = nil,
imageFormat: ImagenImageFormat? = nil, aspectRatio: ImagenAspectRatio? = nil,
public init(negativePrompt: String? = nil, numberOfImages: Int? = nil,
aspectRatio: ImagenAspectRatio? = nil, imageFormat: ImagenImageFormat? = nil,
addWatermark: Bool? = nil) {
self.numberOfImages = numberOfImages
self.negativePrompt = negativePrompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenInlineDataImage {
public struct ImagenInlineImage {
public let mimeType: String
public let data: Data

Expand All @@ -30,7 +30,7 @@ public struct ImagenInlineDataImage {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenInlineDataImage: ImagenImageRepresentable {
extension ImagenInlineImage: ImagenImageRepresentable {
// TODO(andrewheard): Make this public when the SDK supports Imagen operations that take images as
// input (upscaling / editing).
var _internalImagenImage: _InternalImagenImage {
Expand All @@ -43,12 +43,12 @@ extension ImagenInlineDataImage: ImagenImageRepresentable {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenInlineDataImage: Equatable {}
extension ImagenInlineImage: Equatable {}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenInlineDataImage: Decodable {
extension ImagenInlineImage: Decodable {
enum CodingKeys: CodingKey {
case mimeType
case bytesBase64Encoded
Expand Down
16 changes: 9 additions & 7 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public final class ImagenModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

/// Optional image generation settings supplied when the model is initialized; passed to
/// `ImagenModel.imageGenerationParameters` when building a generation request.
let generationConfig: ImagenGenerationConfig?

/// Optional safety settings supplied when the model is initialized; passed to
/// `ImagenModel.imageGenerationParameters` when building a generation request.
let safetySettings: ImagenSafetySettings?

/// Configuration parameters for sending requests to the backend.
Expand All @@ -32,6 +34,7 @@ public final class ImagenModel {
init(name: String,
projectID: String,
apiKey: String,
generationConfig: ImagenGenerationConfig?,
safetySettings: ImagenSafetySettings?,
requestOptions: RequestOptions,
appCheck: AppCheckInterop?,
Expand All @@ -45,13 +48,13 @@ public final class ImagenModel {
auth: auth,
urlSession: urlSession
)
self.generationConfig = generationConfig
self.safetySettings = safetySettings
self.requestOptions = requestOptions
}

public func generateImages(prompt: String,
generationConfig: ImagenGenerationConfig? = nil) async throws
-> ImagenGenerationResponse<ImagenInlineDataImage> {
public func generateImages(prompt: String) async throws
-> ImagenGenerationResponse<ImagenInlineImage> {
return try await generateImages(
prompt: prompt,
parameters: ImagenModel.imageGenerationParameters(
Expand All @@ -62,13 +65,12 @@ public final class ImagenModel {
)
}

public func generateImages(prompt: String, storageURI: String,
generationConfig: ImagenGenerationConfig? = nil) async throws
-> ImagenGenerationResponse<ImagenFileDataImage> {
public func generateImages(prompt: String, gcsUri: String) async throws
-> ImagenGenerationResponse<ImagenGCSImage> {
return try await generateImages(
prompt: prompt,
parameters: ImagenModel.imageGenerationParameters(
storageURI: storageURI,
storageURI: gcsUri,
generationConfig: generationConfig,
safetySettings: safetySettings
)
Expand Down
4 changes: 3 additions & 1 deletion FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ public class VertexAI {
)
}

public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
public func imagenModel(modelName: String, generationConfig: ImagenGenerationConfig? = nil,
safetySettings: ImagenSafetySettings? = nil,
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
return ImagenModel(
name: modelResourceName(modelName: modelName),
projectID: projectID,
apiKey: apiKey,
generationConfig: generationConfig,
safetySettings: safetySettings,
requestOptions: requestOptions,
appCheck: appCheck,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ final class IntegrationTests: XCTestCase {
SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
]

// Generation config passed to `vertex.imagenModel(...)` when the test's Imagen model is created:
// 16:9 landscape aspect ratio, JPEG output with a compression quality of 70.
let imagenGenerationConfig = ImagenGenerationConfig(
aspectRatio: .landscape16x9,
imageFormat: .jpeg(compressionQuality: 70)
)

var vertex: VertexAI!
var model: GenerativeModel!
var imagenModel: ImagenModel!
Expand All @@ -63,6 +68,7 @@ final class IntegrationTests: XCTestCase {
)
imagenModel = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001",
generationConfig: imagenGenerationConfig,
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personFilterLevel: .blockAll
Expand Down Expand Up @@ -253,15 +259,8 @@ final class IntegrationTests: XCTestCase {
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
"""
let generationConfig = ImagenGenerationConfig(
imageFormat: .jpeg(compressionQuality: 70),
aspectRatio: .landscape16x9
)

let response = try await imagenModel.generateImages(
prompt: imagePrompt,
generationConfig: generationConfig
)
let response = try await imagenModel.generateImages(prompt: imagePrompt)

XCTAssertNil(response.filteredReason)
XCTAssertEqual(response.images.count, 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ final class ImageGenerationParametersTests: XCTestCase {
let aspectRatio = ImagenAspectRatio.landscape16x9
let addWatermark = true
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
imageFormat: imageFormat,
numberOfImages: sampleCount,
aspectRatio: aspectRatio,
imageFormat: imageFormat,
addWatermark: addWatermark
)
let expectedParameters = ImageGenerationParameters(
Expand Down Expand Up @@ -146,10 +146,10 @@ final class ImageGenerationParametersTests: XCTestCase {
let aspectRatio = ImagenAspectRatio.portrait3x4
let addWatermark = false
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
imageFormat: imageFormat,
numberOfImages: sampleCount,
aspectRatio: aspectRatio,
imageFormat: imageFormat,
addWatermark: addWatermark
)
let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import XCTest
@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class ImagenFileDataImageTests: XCTestCase {
final class ImagenGCSImageTests: XCTestCase {
let decoder = JSONDecoder()

func testDecodeImage_gcsURI() throws {
Expand All @@ -31,7 +31,7 @@ final class ImagenFileDataImageTests: XCTestCase {
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let image = try decoder.decode(ImagenFileDataImage.self, from: jsonData)
let image = try decoder.decode(ImagenGCSImage.self, from: jsonData)

XCTAssertEqual(image.mimeType, mimeType)
XCTAssertEqual(image.gcsURI, gcsURI)
Expand All @@ -49,10 +49,10 @@ final class ImagenFileDataImageTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

do {
_ = try decoder.decode(ImagenFileDataImage.self, from: jsonData)
_ = try decoder.decode(ImagenGCSImage.self, from: jsonData)
XCTFail("Expected an error; none thrown.")
} catch let DecodingError.keyNotFound(codingKey, _) {
let codingKey = try XCTUnwrap(codingKey as? ImagenFileDataImage.CodingKeys)
let codingKey = try XCTUnwrap(codingKey as? ImagenGCSImage.CodingKeys)
XCTAssertEqual(codingKey, .gcsURI)
} catch {
XCTFail("Expected a DecodingError.keyNotFound error; got \(error).")
Expand All @@ -68,10 +68,10 @@ final class ImagenFileDataImageTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

do {
_ = try decoder.decode(ImagenFileDataImage.self, from: jsonData)
_ = try decoder.decode(ImagenGCSImage.self, from: jsonData)
XCTFail("Expected an error; none thrown.")
} catch let DecodingError.keyNotFound(codingKey, _) {
let codingKey = try XCTUnwrap(codingKey as? ImagenFileDataImage.CodingKeys)
let codingKey = try XCTUnwrap(codingKey as? ImagenGCSImage.CodingKeys)
XCTAssertEqual(codingKey, .mimeType)
} catch {
XCTFail("Expected a DecodingError.keyNotFound error; got \(error).")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
}

func testInitializeRequest_inlineDataImage() throws {
let request = ImagenGenerationRequest<ImagenInlineDataImage>(
let request = ImagenGenerationRequest<ImagenInlineImage>(
model: modelName,
options: requestOptions,
instances: [instance],
Expand All @@ -62,7 +62,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
}

func testInitializeRequest_fileDataImage() throws {
let request = ImagenGenerationRequest<ImagenFileDataImage>(
let request = ImagenGenerationRequest<ImagenGCSImage>(
model: modelName,
options: requestOptions,
instances: [instance],
Expand All @@ -82,7 +82,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
// MARK: - Encoding Tests

func testEncodeRequest_inlineDataImage() throws {
let request = ImagenGenerationRequest<ImagenInlineDataImage>(
let request = ImagenGenerationRequest<ImagenInlineImage>(
model: modelName,
options: RequestOptions(),
instances: [instance],
Expand Down Expand Up @@ -110,7 +110,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
}

func testEncodeRequest_fileDataImage() throws {
let request = ImagenGenerationRequest<ImagenFileDataImage>(
let request = ImagenGenerationRequest<ImagenGCSImage>(
model: modelName,
options: RequestOptions(),
instances: [instance],
Expand Down
Loading

0 comments on commit ed10c4f

Please sign in to comment.