Skip to content

Commit

Permalink
Automatic function calling prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Mar 12, 2024
1 parent 667eccf commit 1ed073f
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 10 deletions.
88 changes: 85 additions & 3 deletions Examples/GenerativeAICLI/Sources/GenerateContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,20 @@ struct GenerateContent: AsyncParsableCommand {
name: modelNameOrDefault(),
apiKey: apiKey,
generationConfig: config,
safetySettings: safetySettings
safetySettings: safetySettings,
tools: [Tool(functionDeclarations: [
FunctionDeclaration(
name: "get_exchange_rate",
description: "Get the exchange rate for currencies between countries",
parameters: getExchangeRateSchema(),
function: getExchangeRateWrapper
),
])],
requestOptions: RequestOptions(apiVersion: "v1beta")
)

let chat = model.startChat()

var parts = [ModelContent.Part]()

if let textPrompt = textPrompt {
Expand All @@ -96,15 +107,16 @@ struct GenerateContent: AsyncParsableCommand {
let input = [ModelContent(parts: parts)]

if isStreaming {
let contentStream = model.generateContentStream(input)
let contentStream = chat.sendMessageStream(input)
print("Generated Content <streaming>:")
for try await content in contentStream {
if let text = content.text {
print(text)
}
}
} else {
let content = try await model.generateContent(input)
// Unary generate content
let content = try await chat.sendMessage(input)
if let text = content.text {
print("Generated Content:\n\(text)")
}
Expand All @@ -123,6 +135,76 @@ struct GenerateContent: AsyncParsableCommand {
return "gemini-1.0-pro"
}
}

// MARK: - Callable Functions

// Returns exchange rates from the Frankfurter API
// This is an example function that a developer might provide.
func getExchangeRate(amount: Double, date: String, from: String,
to: String) async throws -> String {
var urlComponents = URLComponents(string: "https://api.frankfurter.app")!
urlComponents.path = "/\(date)"
urlComponents.queryItems = [
.init(name: "amount", value: String(amount)),
.init(name: "from", value: from),
.init(name: "to", value: to),
]

let (data, _) = try await URLSession.shared.data(from: urlComponents.url!)
return String(data: data, encoding: .utf8)!
}

// This is a wrapper for the `getExchangeRate` function.
func getExchangeRateWrapper(args: JSONObject) async throws -> JSONObject {
// 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
guard case let .string(date) = args["currency_date"] else {
fatalError()
}
guard case let .string(from) = args["currency_from"] else {
fatalError()
}
guard case let .string(to) = args["currency_to"] else {
fatalError()
}
guard case let .number(amount) = args["amount"] else {
fatalError()
}

// 2. Call the wrapped function
let response = try await getExchangeRate(amount: amount, date: date, from: from, to: to)

// 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
return ["content": .string(response)]
}

// Returns the schema of the `getExchangeRate` function
func getExchangeRateSchema() -> Schema {
return Schema(
type: .object,
properties: [
"currency_date": Schema(
type: .string,
description: """
A date that must always be in YYYY-MM-DD format or the value 'latest' if a time period
is not specified
"""
),
"currency_from": Schema(
type: .string,
description: "The currency to convert from in ISO 4217 format"
),
"currency_to": Schema(
type: .string,
description: "The currency to convert to in ISO 4217 format"
),
"amount": Schema(
type: .number,
description: "The amount of currency to convert as a double value"
),
],
required: ["currency_date", "currency_from", "currency_to", "amount"]
)
}
}

enum CLIError: Error {
Expand Down
28 changes: 27 additions & 1 deletion Sources/GoogleAI/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,32 @@ public class Chat {
// Make sure we inject the role into the content received.
let toAdd = ModelContent(role: "model", parts: reply.parts)

var functionResponses = [FunctionResponse]()
for part in reply.parts {
if case let .functionCall(functionCall) = part {
try functionResponses.append(await model.executeFunction(functionCall: functionCall))
}
}

// Call the functions requested by the model, if any.
let functionResponseContent = try ModelContent(
role: "function",
functionResponses.map { functionResponse in
ModelContent.Part.functionResponse(functionResponse)
}
)

// Append the request and successful result to history, then return the value.
history.append(contentsOf: newContent)
history.append(toAdd)
return result

// If no function calls requested, return the results.
if functionResponses.isEmpty {
return result
}

// Re-send the message with the function responses.
return try await sendMessage([functionResponseContent])
}

/// See ``sendMessageStream(_:)-4abs3``.
Expand Down Expand Up @@ -166,6 +188,10 @@ public class Chat {
case .functionCall:
// TODO(andrewheard): Add function call to the chat history when encoding is implemented.
fatalError("Function calling not yet implemented in chat.")

case .functionResponse:
// TODO(andrewheard): Add function response to chat history when encoding is implemented.
fatalError("Function calling not yet implemented in chat.")
}
}
}
Expand Down
104 changes: 101 additions & 3 deletions Sources/GoogleAI/FunctionCalling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,97 @@
import Foundation

/// A predicted function call returned from the model.
public struct FunctionCall: Equatable {
///
/// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functioncall
public struct FunctionCall: Equatable, Encodable {
/// The name of the function to call.
let name: String
public let name: String

/// The function parameters and values.
let args: JSONObject
public let args: JSONObject
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#schema
public class Schema: Encodable {
let type: DataType

let format: String?

let description: String?

let nullable: Bool?

let enumValues: [String]?

let items: Schema?

let properties: [String: Schema]?

let required: [String]?

public init(type: DataType, format: String? = nil, description: String? = nil,
nullable: Bool? = nil,
enumValues: [String]? = nil, items: Schema? = nil,
properties: [String: Schema]? = nil,
required: [String]? = nil) {
self.type = type
self.format = format
self.description = description
self.nullable = nullable
self.enumValues = enumValues
self.items = items
self.properties = properties
self.required = required
}
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#Type
public enum DataType: String, Encodable {
case string = "STRING"
case number = "NUMBER"
case integer = "INTEGER"
case boolean = "BOOLEAN"
case array = "ARRAY"
case object = "OBJECT"
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#FunctionDeclaration
public struct FunctionDeclaration {
let name: String

let description: String

let parameters: Schema

let function: ((JSONObject) async throws -> JSONObject)?

public init(name: String, description: String, parameters: Schema,
function: ((JSONObject) async throws -> JSONObject)?) {
self.name = name
self.description = description
self.parameters = parameters
self.function = function
}
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool
public struct Tool: Encodable {
let functionDeclarations: [FunctionDeclaration]?

public init(functionDeclarations: [FunctionDeclaration]?) {
self.functionDeclarations = functionDeclarations
}
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functionresponse
public struct FunctionResponse: Equatable, Encodable {
let name: String

let response: JSONObject
}

// MARK: - Codable Conformance

extension FunctionCall: Decodable {
enum CodingKeys: CodingKey {
case name
Expand All @@ -39,3 +122,18 @@ extension FunctionCall: Decodable {
}
}
}

extension FunctionDeclaration: Encodable {
enum CodingKeys: String, CodingKey {
case name
case description
case parameters
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(name, forKey: .name)
try container.encode(description, forKey: .description)
try container.encode(parameters, forKey: .parameters)
}
}
2 changes: 2 additions & 0 deletions Sources/GoogleAI/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct GenerateContentRequest {
let contents: [ModelContent]
let generationConfig: GenerationConfig?
let safetySettings: [SafetySetting]?
let tools: [Tool]?
let isStreaming: Bool
let options: RequestOptions
}
Expand All @@ -31,6 +32,7 @@ extension GenerateContentRequest: Encodable {
case contents
case generationConfig
case safetySettings
case tools
}
}

Expand Down
34 changes: 34 additions & 0 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public final class GenerativeModel {
/// The safety settings to be used for prompts.
let safetySettings: [SafetySetting]?

let tools: [Tool]?

/// Configuration parameters for sending requests to the backend.
let requestOptions: RequestOptions

Expand All @@ -52,12 +54,14 @@ public final class GenerativeModel {
apiKey: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
tools: [Tool]? = nil,
requestOptions: RequestOptions = RequestOptions()) {
self.init(
name: name,
apiKey: apiKey,
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
requestOptions: requestOptions,
urlSession: .shared
)
Expand All @@ -68,12 +72,14 @@ public final class GenerativeModel {
apiKey: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
tools: [Tool]? = nil,
requestOptions: RequestOptions = RequestOptions(),
urlSession: URLSession) {
modelResourceName = GenerativeModel.modelResourceName(name: name)
generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
self.generationConfig = generationConfig
self.safetySettings = safetySettings
self.tools = tools
self.requestOptions = requestOptions

Logging.default.info("""
Expand Down Expand Up @@ -119,6 +125,7 @@ public final class GenerativeModel {
contents: content(),
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
isStreaming: false,
options: requestOptions)
response = try await generativeAIService.loadRequest(request: generateContentRequest)
Expand Down Expand Up @@ -190,6 +197,7 @@ public final class GenerativeModel {
contents: evaluatedContent,
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
isStreaming: true,
options: requestOptions)

Expand Down Expand Up @@ -270,6 +278,30 @@ public final class GenerativeModel {
}
}

func executeFunction(functionCall: FunctionCall) async throws -> FunctionResponse {
guard let tools = tools else {
throw GenerateContentError.internalError(underlying: FunctionCallError())
}
guard let tool = tools.first(where: { tool in
tool.functionDeclarations != nil
}) else {
throw GenerateContentError.internalError(underlying: FunctionCallError())
}
guard let functionDeclaration = tool.functionDeclarations?.first(where: { functionDeclaration in
functionDeclaration.name == functionCall.name
}) else {
throw GenerateContentError.internalError(underlying: FunctionCallError())
}
guard let function = functionDeclaration.function else {
throw GenerateContentError.internalError(underlying: FunctionCallError())
}

return try FunctionResponse(
name: functionCall.name,
response: await function(functionCall.args)
)
}

/// Returns a model resource name of the form "models/model-name" based on `name`.
private static func modelResourceName(name: String) -> String {
if name.contains("/") {
Expand Down Expand Up @@ -299,3 +331,5 @@ public final class GenerativeModel {
public enum CountTokensError: Error {
case internalError(underlying: Error)
}

struct FunctionCallError: Error {}
Loading

0 comments on commit 1ed073f

Please sign in to comment.