Skip to content

Commit

Permalink
[Vertex AI] Move ImagenModelConfig params to `ImagenGenerationConfi…
Browse files Browse the repository at this point in the history
…g` (#14340)
  • Loading branch information
andrewheard committed Jan 27, 2025
1 parent 362a01f commit d0e2014
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
public struct ImagenGenerationConfig {
public var numberOfImages: Int?
public var negativePrompt: String?
public var imageFormat: ImagenImageFormat?
public var aspectRatio: ImagenAspectRatio?
public var addWatermark: Bool?

public init(numberOfImages: Int? = nil, negativePrompt: String? = nil,
aspectRatio: ImagenAspectRatio? = nil) {
imageFormat: ImagenImageFormat? = nil, aspectRatio: ImagenAspectRatio? = nil,
addWatermark: Bool? = nil) {
self.numberOfImages = numberOfImages
self.negativePrompt = negativePrompt
self.imageFormat = imageFormat
self.aspectRatio = aspectRatio
self.addWatermark = addWatermark
}
}
11 changes: 2 additions & 9 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ public final class ImagenModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

let modelConfig: ImagenModelConfig?

let safetySettings: ImagenSafetySettings?

/// Configuration parameters for sending requests to the backend.
Expand All @@ -34,7 +32,6 @@ public final class ImagenModel {
init(name: String,
projectID: String,
apiKey: String,
modelConfig: ImagenModelConfig?,
safetySettings: ImagenSafetySettings?,
requestOptions: RequestOptions,
appCheck: AppCheckInterop?,
Expand All @@ -48,7 +45,6 @@ public final class ImagenModel {
auth: auth,
urlSession: urlSession
)
self.modelConfig = modelConfig
self.safetySettings = safetySettings
self.requestOptions = requestOptions
}
Expand All @@ -61,7 +57,6 @@ public final class ImagenModel {
parameters: ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)
)
Expand All @@ -75,7 +70,6 @@ public final class ImagenModel {
parameters: ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)
)
Expand All @@ -96,7 +90,6 @@ public final class ImagenModel {

static func imageGenerationParameters(storageURI: String?,
generationConfig: ImagenGenerationConfig?,
modelConfig: ImagenModelConfig?,
safetySettings: ImagenSafetySettings?)
-> ImageGenerationParameters {
return ImageGenerationParameters(
Expand All @@ -106,13 +99,13 @@ public final class ImagenModel {
aspectRatio: generationConfig?.aspectRatio?.rawValue,
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
personGeneration: safetySettings?.personFilterLevel?.rawValue,
outputOptions: modelConfig?.imageFormat.map {
outputOptions: generationConfig?.imageFormat.map {
ImageGenerationOutputOptions(
mimeType: $0.mimeType,
compressionQuality: $0.compressionQuality
)
},
addWatermark: modelConfig?.addWatermark,
addWatermark: generationConfig?.addWatermark,
includeResponsibleAIFilterReason: true
)
}
Expand Down

This file was deleted.

4 changes: 1 addition & 3 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,12 @@ public class VertexAI {
)
}

public func imagenModel(modelName: String, modelConfig: ImagenModelConfig? = nil,
safetySettings: ImagenSafetySettings? = nil,
public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
return ImagenModel(
name: modelResourceName(modelName: modelName),
projectID: projectID,
apiKey: apiKey,
modelConfig: modelConfig,
safetySettings: safetySettings,
requestOptions: requestOptions,
appCheck: appCheck,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ final class IntegrationTests: XCTestCase {
)
imagenModel = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001",
modelConfig: ImagenModelConfig(imageFormat: .jpeg(compressionQuality: 70)),
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personFilterLevel: .blockAll
Expand Down Expand Up @@ -254,7 +253,10 @@ final class IntegrationTests: XCTestCase {
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
"""
let generationConfig = ImagenGenerationConfig(aspectRatio: .landscape16x9)
let generationConfig = ImagenGenerationConfig(
imageFormat: .jpeg(compressionQuality: 70),
aspectRatio: .landscape16x9
)

let response = try await imagenModel.generateImages(
prompt: imagePrompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: nil,
safetySettings: nil
)

Expand All @@ -64,37 +63,6 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: nil,
modelConfig: nil,
safetySettings: nil
)

XCTAssertEqual(parameters, expectedParameters)
}

func testParameters_includeModelConfig() throws {
let compressionQuality = 80
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
let addWatermark = true
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
let expectedParameters = ImageGenerationParameters(
sampleCount: 1,
storageURI: nil,
negativePrompt: nil,
aspectRatio: nil,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: modelConfig,
safetySettings: nil
)

Expand All @@ -104,11 +72,16 @@ final class ImageGenerationParametersTests: XCTestCase {
func testParameters_includeGenerationConfig() throws {
let sampleCount = 2
let negativePrompt = "test-negative-prompt"
let compressionQuality = 80
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
let aspectRatio = ImagenAspectRatio.landscape16x9
let addWatermark = true
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio
imageFormat: imageFormat,
aspectRatio: aspectRatio,
addWatermark: addWatermark
)
let expectedParameters = ImageGenerationParameters(
sampleCount: sampleCount,
Expand All @@ -117,15 +90,17 @@ final class ImageGenerationParametersTests: XCTestCase {
aspectRatio: aspectRatio.rawValue,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: nil,
addWatermark: nil,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: generationConfig,
modelConfig: nil,
safetySettings: nil
)

Expand Down Expand Up @@ -155,7 +130,6 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: nil,
safetySettings: safetySettings
)

Expand All @@ -168,15 +142,16 @@ final class ImageGenerationParametersTests: XCTestCase {
let storageURI = "gs://test-bucket/path"
let sampleCount = 4
let negativePrompt = "test-negative-prompt"
let imageFormat = ImagenImageFormat.png()
let aspectRatio = ImagenAspectRatio.portrait3x4
let addWatermark = false
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio
imageFormat: imageFormat,
aspectRatio: aspectRatio,
addWatermark: addWatermark
)
let imageFormat = ImagenImageFormat.png()
let addWatermark = false
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone
let personFilterLevel = ImagenPersonFilterLevel.blockAll
let safetySettings = ImagenSafetySettings(
Expand All @@ -201,7 +176,6 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)

Expand Down

0 comments on commit d0e2014

Please sign in to comment.