From c9b336e33a329825c2aff7f137f33d476640114b Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 28 May 2024 11:23:02 -0400 Subject: [PATCH] Add `responseSchema` to `GenerationConfig` (#176) --- Sources/GoogleAI/GenerationConfig.swift | 11 ++++- .../GoogleAITests/GenerationConfigTests.swift | 42 +++++++++++++++++-- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/Sources/GoogleAI/GenerationConfig.swift b/Sources/GoogleAI/GenerationConfig.swift index 850d1e9..c4bd37c 100644 --- a/Sources/GoogleAI/GenerationConfig.swift +++ b/Sources/GoogleAI/GenerationConfig.swift @@ -70,6 +70,12 @@ public struct GenerationConfig { /// - `application/json`: JSON response in the candidates. public let responseMIMEType: String? + /// Output response schema of the generated candidate text. + /// + /// - Note: This only applies when the specified ``responseMIMEType`` supports a schema; currently + /// this is limited to `application/json`. + public let responseSchema: Schema? + /// Creates a new `GenerationConfig` value. /// /// - Parameters: @@ -80,9 +86,11 @@ public struct GenerationConfig { /// - maxOutputTokens: See ``maxOutputTokens``. /// - stopSequences: See ``stopSequences``. /// - responseMIMEType: See ``responseMIMEType``. + /// - responseSchema: See ``responseSchema``. public init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil, candidateCount: Int? = nil, maxOutputTokens: Int? = nil, - stopSequences: [String]? = nil, responseMIMEType: String? = nil) { + stopSequences: [String]? = nil, responseMIMEType: String? = nil, + responseSchema: Schema? = nil) { // Explicit init because otherwise if we re-arrange the above variables it changes the API // surface. self.temperature = temperature @@ -92,6 +100,7 @@ public struct GenerationConfig { self.maxOutputTokens = maxOutputTokens self.stopSequences = stopSequences self.responseMIMEType = responseMIMEType + self.responseSchema = responseSchema } } diff --git a/Tests/GoogleAITests/GenerationConfigTests.swift b/Tests/GoogleAITests/GenerationConfigTests.swift index 1ad3eaf..06dada7 100644 --- a/Tests/GoogleAITests/GenerationConfigTests.swift +++ b/Tests/GoogleAITests/GenerationConfigTests.swift @@ -48,7 +48,11 @@ final class GenerationConfigTests: XCTestCase { let candidateCount = 2 let maxOutputTokens = 256 let stopSequences = ["END", "DONE"] - let responseMIMEType = "text/plain" + let responseMIMEType = "application/json" + let schemaType = DataType.object + let fieldName = "test-field" + let fieldType = DataType.string + let responseSchema = Schema(type: schemaType, properties: [fieldName: Schema(type: fieldType)]) let generationConfig = GenerationConfig( temperature: temperature, topP: topP, @@ -56,7 +60,8 @@ final class GenerationConfigTests: XCTestCase { candidateCount: candidateCount, maxOutputTokens: maxOutputTokens, stopSequences: stopSequences, - responseMIMEType: responseMIMEType + responseMIMEType: responseMIMEType, + responseSchema: responseSchema ) let jsonData = try encoder.encode(generationConfig) @@ -67,6 +72,14 @@ final class GenerationConfigTests: XCTestCase { "candidateCount" : \(candidateCount), "maxOutputTokens" : \(maxOutputTokens), "responseMIMEType" : "\(responseMIMEType)", + "responseSchema" : { + "properties" : { + "\(fieldName)" : { + "type" : "\(fieldType.rawValue)" + } + }, + "type" : "\(schemaType.rawValue)" + }, "stopSequences" : [ "END", "DONE" @@ -79,7 +92,7 @@ final class GenerationConfigTests: XCTestCase { } func testEncodeGenerationConfig_responseMIMEType() throws { - let mimeType = "image/jpeg" + let mimeType = "text/plain" let generationConfig = GenerationConfig(responseMIMEType: mimeType) let jsonData = try encoder.encode(generationConfig) @@ -91,4 +104,27 @@ final class GenerationConfigTests: XCTestCase { } """) } + + func testEncodeGenerationConfig_responseMIMETypeWithSchema() throws { + let mimeType = "application/json" + let schemaType = DataType.array + let arrayItemType = DataType.integer + let schema = Schema(type: schemaType, items: Schema(type: arrayItemType)) + let generationConfig = GenerationConfig(responseMIMEType: mimeType, responseSchema: schema) + + let jsonData = try encoder.encode(generationConfig) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "responseMIMEType" : "\(mimeType)", + "responseSchema" : { + "items" : { + "type" : "\(arrayItemType.rawValue)" + }, + "type" : "\(schemaType.rawValue)" + } + } + """) + } }