From 1ed073fcd343cfe3efda88f51f29ea13ee6e924f Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Mon, 11 Mar 2024 12:09:11 -0400 Subject: [PATCH] Automatic function calling prototype --- .../Sources/GenerateContent.swift | 88 ++++++++++++++- Sources/GoogleAI/Chat.swift | 28 ++++- Sources/GoogleAI/FunctionCalling.swift | 104 +++++++++++++++++- Sources/GoogleAI/GenerateContentRequest.swift | 2 + Sources/GoogleAI/GenerativeModel.swift | 34 ++++++ Sources/GoogleAI/ModelContent.swift | 10 +- 6 files changed, 256 insertions(+), 10 deletions(-) diff --git a/Examples/GenerativeAICLI/Sources/GenerateContent.swift b/Examples/GenerativeAICLI/Sources/GenerateContent.swift index ab71c43..488b090 100644 --- a/Examples/GenerativeAICLI/Sources/GenerateContent.swift +++ b/Examples/GenerativeAICLI/Sources/GenerateContent.swift @@ -70,9 +70,20 @@ struct GenerateContent: AsyncParsableCommand { name: modelNameOrDefault(), apiKey: apiKey, generationConfig: config, - safetySettings: safetySettings + safetySettings: safetySettings, + tools: [Tool(functionDeclarations: [ + FunctionDeclaration( + name: "get_exchange_rate", + description: "Get the exchange rate for currencies between countries", + parameters: getExchangeRateSchema(), + function: getExchangeRateWrapper + ), + ])], + requestOptions: RequestOptions(apiVersion: "v1beta") ) + let chat = model.startChat() + var parts = [ModelContent.Part]() if let textPrompt = textPrompt { @@ -96,7 +107,7 @@ struct GenerateContent: AsyncParsableCommand { let input = [ModelContent(parts: parts)] if isStreaming { - let contentStream = model.generateContentStream(input) + let contentStream = chat.sendMessageStream(input) print("Generated Content :") for try await content in contentStream { if let text = content.text { @@ -104,7 +115,8 @@ struct GenerateContent: AsyncParsableCommand { } } } else { - let content = try await model.generateContent(input) + // Unary generate content + let content = try await chat.sendMessage(input) if let text = content.text { print("Generated Content:\n\(text)") } @@ -123,6 +135,76 @@ struct GenerateContent: AsyncParsableCommand { return "gemini-1.0-pro" } } + + // MARK: - Callable Functions + + // Returns exchange rates from the Frankfurter API + // This is an example function that a developer might provide. + func getExchangeRate(amount: Double, date: String, from: String, + to: String) async throws -> String { + var urlComponents = URLComponents(string: "https://api.frankfurter.app")! + urlComponents.path = "/\(date)" + urlComponents.queryItems = [ + .init(name: "amount", value: String(amount)), + .init(name: "from", value: from), + .init(name: "to", value: to), + ] + + let (data, _) = try await URLSession.shared.data(from: urlComponents.url!) + return String(data: data, encoding: .utf8)! + } + + // This is a wrapper for the `getExchangeRate` function. + func getExchangeRateWrapper(args: JSONObject) async throws -> JSONObject { + // 1. Validate and extract the parameters provided by the model (from a `FunctionCall`) + guard case let .string(date) = args["currency_date"] else { + fatalError() + } + guard case let .string(from) = args["currency_from"] else { + fatalError() + } + guard case let .string(to) = args["currency_to"] else { + fatalError() + } + guard case let .number(amount) = args["amount"] else { + fatalError() + } + + // 2. Call the wrapped function + let response = try await getExchangeRate(amount: amount, date: date, from: from, to: to) + + // 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`) + return ["content": .string(response)] + } + + // Returns the schema of the `getExchangeRate` function + func getExchangeRateSchema() -> Schema { + return Schema( + type: .object, + properties: [ + "currency_date": Schema( + type: .string, + description: """ + A date that must always be in YYYY-MM-DD format or the value 'latest' if a time period + is not specified + """ + ), + "currency_from": Schema( + type: .string, + description: "The currency to convert from in ISO 4217 format" + ), + "currency_to": Schema( + type: .string, + description: "The currency to convert to in ISO 4217 format" + ), + "amount": Schema( + type: .number, + description: "The amount of currency to convert as a double value" + ), + ], + required: ["currency_date", "currency_from", "currency_to", "amount"] + ) + } } enum CLIError: Error { diff --git a/Sources/GoogleAI/Chat.swift b/Sources/GoogleAI/Chat.swift index 6e7e885..74fdebd 100644 --- a/Sources/GoogleAI/Chat.swift +++ b/Sources/GoogleAI/Chat.swift @@ -70,10 +70,32 @@ public class Chat { // Make sure we inject the role into the content received. let toAdd = ModelContent(role: "model", parts: reply.parts) + var functionResponses = [FunctionResponse]() + for part in reply.parts { + if case let .functionCall(functionCall) = part { + try functionResponses.append(await model.executeFunction(functionCall: functionCall)) + } + } + + // Call the functions requested by the model, if any. + let functionResponseContent = try ModelContent( + role: "function", + functionResponses.map { functionResponse in + ModelContent.Part.functionResponse(functionResponse) + } + ) + // Append the request and successful result to history, then return the value. history.append(contentsOf: newContent) history.append(toAdd) - return result + + // If no function calls requested, return the results. + if functionResponses.isEmpty { + return result + } + + // Re-send the message with the function responses. + return try await sendMessage([functionResponseContent]) } /// See ``sendMessageStream(_:)-4abs3``. @@ -166,6 +188,10 @@ public class Chat { case .functionCall: // TODO(andrewheard): Add function call to the chat history when encoding is implemented. fatalError("Function calling not yet implemented in chat.") + + case .functionResponse: + // TODO(andrewheard): Add function response to chat history when encoding is implemented. + fatalError("Function calling not yet implemented in chat.") } } } diff --git a/Sources/GoogleAI/FunctionCalling.swift b/Sources/GoogleAI/FunctionCalling.swift index 5d8ded5..e76e434 100644 --- a/Sources/GoogleAI/FunctionCalling.swift +++ b/Sources/GoogleAI/FunctionCalling.swift @@ -15,14 +15,97 @@ import Foundation /// A predicted function call returned from the model. -public struct FunctionCall: Equatable { +/// +/// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functioncall +public struct FunctionCall: Equatable, Encodable { /// The name of the function to call. - let name: String + public let name: String /// The function parameters and values. - let args: JSONObject + public let args: JSONObject } +// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#schema +public class Schema: Encodable { + let type: DataType + + let format: String? + + let description: String? + + let nullable: Bool? + + let enumValues: [String]? + + let items: Schema? + + let properties: [String: Schema]? + + let required: [String]? + + public init(type: DataType, format: String? = nil, description: String? = nil, + nullable: Bool? = nil, + enumValues: [String]? = nil, items: Schema? = nil, + properties: [String: Schema]? = nil, + required: [String]? = nil) { + self.type = type + self.format = format + self.description = description + self.nullable = nullable + self.enumValues = enumValues + self.items = items + self.properties = properties + self.required = required + } +} + +// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#Type +public enum DataType: String, Encodable { + case string = "STRING" + case number = "NUMBER" + case integer = "INTEGER" + case boolean = "BOOLEAN" + case array = "ARRAY" + case object = "OBJECT" +} + +// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#FunctionDeclaration +public struct FunctionDeclaration { + let name: String + + let description: String + + let parameters: Schema + + let function: ((JSONObject) async throws -> JSONObject)? + + public init(name: String, description: String, parameters: Schema, + function: ((JSONObject) async throws -> JSONObject)?) { + self.name = name + self.description = description + self.parameters = parameters + self.function = function + } +} + +// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool +public struct Tool: Encodable { + let functionDeclarations: [FunctionDeclaration]? + + public init(functionDeclarations: [FunctionDeclaration]?) { + self.functionDeclarations = functionDeclarations + } +} + +// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functionresponse +public struct FunctionResponse: Equatable, Encodable { + let name: String + + let response: JSONObject +} + +// MARK: - Codable Conformance + extension FunctionCall: Decodable { enum CodingKeys: CodingKey { case name @@ -39,3 +122,18 @@ extension FunctionCall: Decodable { } } } + +extension FunctionDeclaration: Encodable { + enum CodingKeys: String, CodingKey { + case name + case description + case parameters + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(name, forKey: .name) + try container.encode(description, forKey: .description) + try container.encode(parameters, forKey: .parameters) + } +} diff --git a/Sources/GoogleAI/GenerateContentRequest.swift b/Sources/GoogleAI/GenerateContentRequest.swift index 417260b..535ac4d 100644 --- a/Sources/GoogleAI/GenerateContentRequest.swift +++ b/Sources/GoogleAI/GenerateContentRequest.swift @@ -21,6 +21,7 @@ struct GenerateContentRequest { let contents: [ModelContent] let generationConfig: GenerationConfig? let safetySettings: [SafetySetting]? + let tools: [Tool]? let isStreaming: Bool let options: RequestOptions } @@ -31,6 +32,7 @@ extension GenerateContentRequest: Encodable { case contents case generationConfig case safetySettings + case tools } } diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift index 03e0191..39fdbb8 100644 --- a/Sources/GoogleAI/GenerativeModel.swift +++ b/Sources/GoogleAI/GenerativeModel.swift @@ -36,6 +36,8 @@ public final class GenerativeModel { /// The safety settings to be used for prompts. let safetySettings: [SafetySetting]? + let tools: [Tool]? + /// Configuration parameters for sending requests to the backend. let requestOptions: RequestOptions @@ -52,12 +54,14 @@ public final class GenerativeModel { apiKey: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, + tools: [Tool]? = nil, requestOptions: RequestOptions = RequestOptions()) { self.init( name: name, apiKey: apiKey, generationConfig: generationConfig, safetySettings: safetySettings, + tools: tools, requestOptions: requestOptions, urlSession: .shared ) @@ -68,12 +72,14 @@ public final class GenerativeModel { apiKey: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, + tools: [Tool]? = nil, requestOptions: RequestOptions = RequestOptions(), urlSession: URLSession) { modelResourceName = GenerativeModel.modelResourceName(name: name) generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession) self.generationConfig = generationConfig self.safetySettings = safetySettings + self.tools = tools self.requestOptions = requestOptions Logging.default.info(""" @@ -119,6 +125,7 @@ public final class GenerativeModel { contents: content(), generationConfig: generationConfig, safetySettings: safetySettings, + tools: tools, isStreaming: false, options: requestOptions) response = try await generativeAIService.loadRequest(request: generateContentRequest) @@ -190,6 +197,7 @@ public final class GenerativeModel { contents: evaluatedContent, generationConfig: generationConfig, safetySettings: safetySettings, + tools: tools, isStreaming: true, options: requestOptions) @@ -270,6 +278,30 @@ public final class GenerativeModel { } } + func executeFunction(functionCall: FunctionCall) async throws -> FunctionResponse { + guard let tools = tools else { + throw GenerateContentError.internalError(underlying: FunctionCallError()) + } + guard let tool = tools.first(where: { tool in + tool.functionDeclarations != nil + }) else { + throw GenerateContentError.internalError(underlying: FunctionCallError()) + } + guard let functionDeclaration = tool.functionDeclarations?.first(where: { functionDeclaration in + functionDeclaration.name == functionCall.name + }) else { + throw GenerateContentError.internalError(underlying: FunctionCallError()) + } + guard let function = functionDeclaration.function else { + throw GenerateContentError.internalError(underlying: FunctionCallError()) + } + + return try FunctionResponse( + name: functionCall.name, + response: await function(functionCall.args) + ) + } + /// Returns a model resource name of the form "models/model-name" based on `name`. private static func modelResourceName(name: String) -> String { if name.contains("/") { @@ -299,3 +331,5 @@ public final class GenerativeModel { public enum CountTokensError: Error { case internalError(underlying: Error) } + +struct FunctionCallError: Error {} diff --git a/Sources/GoogleAI/ModelContent.swift b/Sources/GoogleAI/ModelContent.swift index 2ce8876..4aefe7b 100644 --- a/Sources/GoogleAI/ModelContent.swift +++ b/Sources/GoogleAI/ModelContent.swift @@ -26,6 +26,7 @@ public struct ModelContent: Codable, Equatable { case text case inlineData case functionCall + case functionResponse } enum InlineDataKeys: String, CodingKey { @@ -42,6 +43,8 @@ public struct ModelContent: Codable, Equatable { /// A predicted function call returned from the model. case functionCall(FunctionCall) + case functionResponse(FunctionResponse) + // MARK: Convenience Initializers /// Convenience function for populating a Part with JPEG data. @@ -68,9 +71,10 @@ public struct ModelContent: Codable, Equatable { ) try inlineDataContainer.encode(mimetype, forKey: .mimeType) try inlineDataContainer.encode(bytes, forKey: .bytes) - case .functionCall: - // TODO(andrewheard): Encode FunctionCalls when when encoding is implemented. - fatalError("FunctionCall encoding not implemented.") + case let .functionCall(functionCall): + try container.encode(functionCall, forKey: .functionCall) + case let .functionResponse(functionResponse): + try container.encode(functionResponse, forKey: .functionResponse) } }