Skip to content

Commit

Permalink
Add tests for unknown enum values in responses (#60)
Browse files Browse the repository at this point in the history
andrewheard authored Dec 13, 2023
1 parent ea39d69 commit a9f807d
Showing 5 changed files with 170 additions and 58 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
{
"candidates": [
{
"content": {
"parts": [
{
"text": "Some text"
}
]
},
"finishReason": "FAKE_NEW_FINISH_REASON",
"index": 0,
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
}
]
}
],
"promptFeedback": {
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"promptFeedback": {
"blockReason": "FAKE_NEW_BLOCK_REASON",
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
{
"candidates": [
{
"content": {
"parts": [
{
"text": "Some text"
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0,
"safetyRatings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "MEDIUM"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "FAKE_NEW_HARM_PROBABILITY"
},
{
"category": "FAKE_NEW_HARM_CATEGORY",
"probability": "HIGH"
}
]
}
],
"promptFeedback": {
"safetyRatings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "MEDIUM"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "FAKE_NEW_HARM_PROBABILITY"
},
{
"category": "FAKE_NEW_HARM_CATEGORY",
"probability": "HIGH"
}
]
}
}

This file was deleted.

54 changes: 48 additions & 6 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
@@ -134,17 +134,23 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(promptFeedback.safetyRatings, safetyRatingsNegligible)
}

func testGenerateContent_success_unknownEnum() async throws {
func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
let expectedSafetyRatings = [
SafetyRating(category: .harassment, probability: .medium),
SafetyRating(category: .dangerousContent, probability: .unknown),
SafetyRating(category: .unknown, probability: .high),
]
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-unknown-enum",
forResource: "unary-success-unknown-enum-safety-ratings",
withExtension: "json"
)

let content = try await model.generateContent(testPrompt)
let response = try await model.generateContent(testPrompt)

XCTAssertNotNil(content.text)
// TODO: Add assertions
XCTAssertEqual(response.text, "Some text")
XCTAssertEqual(response.candidates.first?.safetyRatings, expectedSafetyRatings)
XCTAssertEqual(response.promptFeedback?.safetyRatings, expectedSafetyRatings)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
@@ -261,7 +267,43 @@ final class GenerativeModelTests: XCTestCase {
} catch let GenerateContentError.promptBlocked(response) {
XCTAssertNil(response.text)
} catch {
XCTFail("Should throw a promptBlocked]")
XCTFail("Should throw a promptBlocked")
}
}

func testGenerateContent_failure_unknownEnum_finishReason() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-failure-unknown-enum-finish-reason",
withExtension: "json"
)

do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GenerateContentError.responseStoppedEarly(reason, response) {
XCTAssertEqual(reason, .unknown)
XCTAssertEqual(response.text, "Some text")
} catch {
XCTFail("Should throw a responseStoppedEarly")
}
}

func testGenerateContent_failure_unknownEnum_promptBlocked() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-failure-unknown-enum-prompt-blocked",
withExtension: "json"
)

do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GenerateContentError.promptBlocked(response) {
let promptFeedback = try XCTUnwrap(response.promptFeedback)
XCTAssertEqual(promptFeedback.blockReason, .unknown)
} catch {
XCTFail("Should throw a promptBlocked")
}
}

0 comments on commit a9f807d

Please sign in to comment.