From 3d265048d3735a364f137167511cf830362b4adc Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 10 Apr 2024 17:57:36 -0400 Subject: [PATCH] Add system instruction support (#129) --- Sources/GoogleAI/GenerateContentRequest.swift | 2 ++ Sources/GoogleAI/GenerativeModel.swift | 11 +++++++++++ Tests/GoogleAITests/GoogleAITests.swift | 9 ++++++++- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/Sources/GoogleAI/GenerateContentRequest.swift b/Sources/GoogleAI/GenerateContentRequest.swift index e0e9c2f..05abadf 100644 --- a/Sources/GoogleAI/GenerateContentRequest.swift +++ b/Sources/GoogleAI/GenerateContentRequest.swift @@ -23,6 +23,7 @@ struct GenerateContentRequest { let safetySettings: [SafetySetting]? let tools: [Tool]? let toolConfig: ToolConfig? + let systemInstruction: ModelContent? let isStreaming: Bool let options: RequestOptions } @@ -35,6 +36,7 @@ extension GenerateContentRequest: Encodable { case safetySettings case tools case toolConfig + case systemInstruction } } diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift index c2eb149..5e18ab2 100644 --- a/Sources/GoogleAI/GenerativeModel.swift +++ b/Sources/GoogleAI/GenerativeModel.swift @@ -39,6 +39,9 @@ public final class GenerativeModel { /// Tool configuration for any `Tool` specified in the request. let toolConfig: ToolConfig? + /// Instructions that direct the model to behave a certain way. + let systemInstruction: ModelContent? + /// Configuration parameters for sending requests to the backend. let requestOptions: RequestOptions @@ -51,6 +54,8 @@ public final class GenerativeModel { /// - generationConfig: The content generation parameters your model should use. /// - safetySettings: A value describing what types of harmful content your model should allow. /// - tools: A list of ``Tool`` objects that the model may use to generate the next response. + /// - systemInstruction: Instructions that direct the model to behave a certain way; currently + /// only text content is supported. /// - toolConfig: Tool configuration for any `Tool` specified in the request. /// - requestOptions Configuration parameters for sending requests to the backend. public convenience init(name: String, @@ -59,6 +64,7 @@ public final class GenerativeModel { safetySettings: [SafetySetting]? = nil, tools: [Tool]? = nil, toolConfig: ToolConfig? = nil, + systemInstruction: ModelContent? = nil, requestOptions: RequestOptions = RequestOptions()) { self.init( name: name, @@ -67,6 +73,7 @@ public final class GenerativeModel { safetySettings: safetySettings, tools: tools, toolConfig: toolConfig, + systemInstruction: systemInstruction, requestOptions: requestOptions, urlSession: .shared ) @@ -79,6 +86,7 @@ public final class GenerativeModel { safetySettings: [SafetySetting]? = nil, tools: [Tool]? = nil, toolConfig: ToolConfig? = nil, + systemInstruction: ModelContent? = nil, requestOptions: RequestOptions = RequestOptions(), urlSession: URLSession) { modelResourceName = GenerativeModel.modelResourceName(name: name) @@ -87,6 +95,7 @@ public final class GenerativeModel { self.safetySettings = safetySettings self.tools = tools self.toolConfig = toolConfig + self.systemInstruction = systemInstruction self.requestOptions = requestOptions Logging.default.info(""" @@ -134,6 +143,7 @@ public final class GenerativeModel { safetySettings: safetySettings, tools: tools, toolConfig: toolConfig, + systemInstruction: systemInstruction, isStreaming: false, options: requestOptions) response = try await generativeAIService.loadRequest(request: generateContentRequest) @@ -207,6 +217,7 @@ public final class GenerativeModel { safetySettings: safetySettings, tools: tools, toolConfig: toolConfig, + systemInstruction: systemInstruction, isStreaming: true, options: requestOptions) diff --git a/Tests/GoogleAITests/GoogleAITests.swift b/Tests/GoogleAITests/GoogleAITests.swift index cbc9252..c0764bb 100644 --- a/Tests/GoogleAITests/GoogleAITests.swift +++ b/Tests/GoogleAITests/GoogleAITests.swift @@ -30,17 +30,24 @@ final class GoogleGenerativeAITests: XCTestCase { maxOutputTokens: 256, stopSequences: ["..."]) let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)] + let systemInstruction = ModelContent(role: "system", parts: [.text("Talk like a pirate.")]) // Permutations without optional arguments. let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY") let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", safetySettings: filters) let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", generationConfig: config) + let _ = GenerativeModel( + name: "gemini-1.0-pro", + apiKey: "API_KEY", + systemInstruction: systemInstruction + ) // All arguments passed. let genAI = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", generationConfig: config, // Optional - safetySettings: filters // Optional + safetySettings: filters, // Optional + systemInstruction: systemInstruction // Optional ) // Full Typed Usage let pngData = Data() // ....