Skip to content

Commit

Permalink
[Vertex AI] Add Imagen integration tests for GCS and filtering (#14403)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Jan 30, 2025
1 parent ed10c4f commit 7cbcc47
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 38 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAuth
import FirebaseCore
import FirebaseStorage
import FirebaseVertexAI
import Testing
import VertexAITestApp

#if canImport(UIKit)
import UIKit
#endif // canImport(UIKit)

@Suite(
.enabled(
if: ProcessInfo.processInfo.environment["VTXIntegrationImagen"] != nil,
"Only runs if the environment variable VTXIntegrationImagen is set."
),
.serialized
)
struct ImagenIntegrationTests {
var vertex: VertexAI
var storage: Storage
var userID1: String

init() async throws {
let authResult = try await Auth.auth().signIn(
withEmail: Credentials.emailAddress1,
password: Credentials.emailPassword1
)
userID1 = authResult.user.uid

vertex = VertexAI.vertexAI()
storage = Storage.storage()
}

@Test func generateImage_inlineImage() async throws {
let generationConfig = ImagenGenerationConfig(
negativePrompt: "snow, frost",
aspectRatio: .portrait3x4,
imageFormat: .png(),
addWatermark: false
)
let model = vertex.imagenModel(
modelName: "imagen-3.0-generate-002",
generationConfig: generationConfig,
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personFilterLevel: .allowAdult
)
)
let imagePrompt = "A woman, 35mm portrait, in front of a mountain range"

let response = try await model.generateImages(prompt: imagePrompt)

#expect(response.filteredReason == nil)
#expect(response.images.count == 1)
let image = try #require(response.images.first)
#expect(image.mimeType == "image/png")
#expect(image.data.isEmpty == false)
#if canImport(UIKit)
let uiImage = try #require(UIImage(data: image.data))
#expect(uiImage.size.width == 896.0)
#expect(uiImage.size.height == 1280.0)
#endif // canImport(UIKit)
}

@Test func generateImages_gcsImages() async throws {
let generationConfig = ImagenGenerationConfig(
numberOfImages: 3,
aspectRatio: .landscape16x9,
imageFormat: .jpeg(compressionQuality: 60),
addWatermark: true
)
let model = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001",
generationConfig: generationConfig,
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockMediumAndAbove,
personFilterLevel: .blockAll
)
)
let prompt = "A dense jungle with light streaming through the treetops"
let storageRef = storage.reference(
withPath: "/vertexai/imagen/authenticated/user/\(userID1)"
)

let response = try await model.generateImages(prompt: prompt, gcsUri: storageRef.gsURI)

#expect(response.filteredReason == nil)
#expect(response.images.count == generationConfig.numberOfImages)
for image in response.images {
#expect(image.mimeType == "image/jpeg")
let imageRef = storage.reference(forURL: image.gcsURI)
let imageData = try await imageRef.data(maxSize: 1_000_000) // ~1MB
#expect(imageData.isEmpty == false)
#if canImport(UIKit)
let uiImage = try #require(UIImage(data: imageData))
#expect(uiImage.size.width == 1408.0)
#expect(uiImage.size.height == 768.0)
#endif // canImport(UIKit)
try await imageRef.delete()
}
}

@Test func generateImage_allImagesFilteredOut() async throws {
let generationConfig = ImagenGenerationConfig(numberOfImages: 2, imageFormat: .jpeg())
let model = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001",
generationConfig: generationConfig,
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personFilterLevel: .blockAll
)
)
let imagePrompt = "A woman, 35mm portrait, in front of a mountain range"

let response = try await model.generateImages(prompt: imagePrompt)

#expect(response.images.isEmpty)
let filteredReason = try #require(response.filteredReason)
// 39322892: Detects a person or face when it isn't allowed due to the request safety settings.
#expect(filteredReason.contains("39322892"))
// TODO(#14221): Update implementation and test to throw an exception when all filtered out.
}

// TODO(#14221): Add an integration test for the prompt being blocked.

// TODO(#14221): Add integration tests for validating that Storage Rules are enforced.
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,8 @@ final class IntegrationTests: XCTestCase {
SafetySetting(harmCategory: .civicIntegrity, threshold: .blockLowAndAbove),
]

let imagenGenerationConfig = ImagenGenerationConfig(
aspectRatio: .landscape16x9,
imageFormat: .jpeg(compressionQuality: 70)
)

var vertex: VertexAI!
var model: GenerativeModel!
var imagenModel: ImagenModel!
var storage: Storage!
var userID1 = ""

Expand All @@ -66,14 +60,6 @@ final class IntegrationTests: XCTestCase {
toolConfig: .init(functionCallingConfig: .none()),
systemInstruction: systemInstruction
)
imagenModel = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001",
generationConfig: imagenGenerationConfig,
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personFilterLevel: .blockAll
)
)

storage = Storage.storage()
}
Expand Down Expand Up @@ -249,30 +235,6 @@ final class IntegrationTests: XCTestCase {
XCTAssertTrue(String(describing: error).contains("Firebase App Check token is invalid"))
}
}

// MARK: - Imagen

func testGenerateImage_inlineData() async throws {
try IntegrationTestUtils.skipUnless(environmentVariable: "VTXIntegrationImagen")
let imagePrompt = """
A realistic photo of a male lion, mane thick and dark, standing proudly on a rocky outcrop
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
"""

let response = try await imagenModel.generateImages(prompt: imagePrompt)

XCTAssertNil(response.filteredReason)
XCTAssertEqual(response.images.count, 1)
let image = try XCTUnwrap(response.images.first)
XCTAssertEqual(image.mimeType, "image/jpeg")
XCTAssertGreaterThan(image.data.count, 0)
#if canImport(UIKit)
let uiImage = try XCTUnwrap(UIImage(data: image.data))
XCTAssertEqual(uiImage.size.width, 1408.0)
XCTAssertEqual(uiImage.size.height, 768.0)
#endif
}
}

extension StorageReference {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

/* Begin PBXBuildFile section */
862218812D04E098007ED2D4 /* IntegrationTestUtils.swift in Sources */ = {isa = PBXBuildFile; fileRef = 862218802D04E08D007ED2D4 /* IntegrationTestUtils.swift */; };
864F8F712D4980DD0002EA7E /* ImagenIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 864F8F702D4980D60002EA7E /* ImagenIntegrationTests.swift */; };
8661385C2CC943DD00F4B78E /* TestApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8661385B2CC943DD00F4B78E /* TestApp.swift */; };
8661385E2CC943DD00F4B78E /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8661385D2CC943DD00F4B78E /* ContentView.swift */; };
8661386E2CC943DE00F4B78E /* IntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8661386D2CC943DE00F4B78E /* IntegrationTests.swift */; };
Expand Down Expand Up @@ -35,6 +36,7 @@

/* Begin PBXFileReference section */
862218802D04E08D007ED2D4 /* IntegrationTestUtils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = IntegrationTestUtils.swift; sourceTree = "<group>"; };
864F8F702D4980D60002EA7E /* ImagenIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImagenIntegrationTests.swift; sourceTree = "<group>"; };
866138582CC943DD00F4B78E /* VertexAITestApp-SPM.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "VertexAITestApp-SPM.app"; sourceTree = BUILT_PRODUCTS_DIR; };
8661385B2CC943DD00F4B78E /* TestApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TestApp.swift; sourceTree = "<group>"; };
8661385D2CC943DD00F4B78E /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
Expand Down Expand Up @@ -124,6 +126,7 @@
children = (
868A7C4D2CCC1F4700E449DD /* Credentials.swift */,
8661386D2CC943DE00F4B78E /* IntegrationTests.swift */,
864F8F702D4980D60002EA7E /* ImagenIntegrationTests.swift */,
862218802D04E08D007ED2D4 /* IntegrationTestUtils.swift */,
);
path = Integration;
Expand Down Expand Up @@ -268,6 +271,7 @@
files = (
8698D7462CD3CF3600ABA833 /* FirebaseAppTestUtils.swift in Sources */,
868A7C4F2CCC229F00E449DD /* Credentials.swift in Sources */,
864F8F712D4980DD0002EA7E /* ImagenIntegrationTests.swift in Sources */,
862218812D04E098007ED2D4 /* IntegrationTestUtils.swift in Sources */,
8661386E2CC943DE00F4B78E /* IntegrationTests.swift in Sources */,
);
Expand Down

0 comments on commit 7cbcc47

Please sign in to comment.