Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function calling support #116

Merged
merged 12 commits into from
Mar 26, 2024
Prev Previous commit
Next Next commit
Use static exchange rate data
andrewheard committed Mar 11, 2024
commit 44edbb1bab0d2e87a44395e865fcb76732afcaca
166 changes: 37 additions & 129 deletions Examples/GenerativeAICLI/Sources/GenerateContent.swift
Original file line number Diff line number Diff line change
@@ -81,12 +81,24 @@ struct GenerateContent: AsyncParsableCommand {
FunctionDeclaration(
name: "get_exchange_rate",
description: "Get the exchange rate for currencies between countries",
parameters: getExchangeRateSchema()
),
FunctionDeclaration(
name: "sum_integer_list",
description: "Sums a list of integer values",
parameters: sumIntegerListSchema()
parameters: Schema(
type: .object,
properties: [
"currency_from": Schema(
type: .string,
format: "enum",
description: "The currency to convert from in ISO 4217 format",
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
),
"currency_to": Schema(
type: .string,
format: "enum",
description: "The currency to convert to in ISO 4217 format",
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
),
],
required: ["currency_from", "currency_to"]
)
),
])],
requestOptions: RequestOptions(apiVersion: "v1beta")
@@ -161,17 +173,8 @@ struct GenerateContent: AsyncParsableCommand {
parts: [ModelContent.Part.functionCall(functionCall)]
))
switch functionCall.name {
case "sum_integer_list":
let sum = sumIntegerListWrapper(args: functionCall.args)
input.append(ModelContent(
role: "function",
parts: [ModelContent.Part.functionResponse(FunctionResponse(
name: "sum_integer_list",
response: sum
))]
))
case "get_exchange_rate":
let exchangeRates = try await getExchangeRateWrapper(args: functionCall.args)
let exchangeRates = getExchangeRate(args: functionCall.args)
input.append(ModelContent(
role: "function",
parts: [ModelContent.Part.functionResponse(FunctionResponse(
@@ -198,128 +201,33 @@ struct GenerateContent: AsyncParsableCommand {

// MARK: - Callable Functions

// Returns exchange rates from the Frankfurter API
// This is an example function that a developer might provide.
func getExchangeRate(amount: Double, date: String, from: String,
to: String) async throws -> String {
var urlComponents = URLComponents(string: "https://api.frankfurter.app")!
urlComponents.path = "/\(date)"
urlComponents.queryItems = [
.init(name: "amount", value: String(amount)),
.init(name: "from", value: from),
.init(name: "to", value: to),
]

let (data, _) = try await URLSession.shared.data(from: urlComponents.url!)
return String(data: data, encoding: .utf8)!
}

// This is a wrapper for the `getExchangeRate` function.
func getExchangeRateWrapper(args: JSONObject) async throws -> JSONObject {
func getExchangeRate(args: JSONObject) -> JSONObject {
// 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
guard case let .string(date) = args["currency_date"] else {
fatalError()
}
guard case let .string(from) = args["currency_from"] else {
fatalError()
fatalError("Missing `currency_from` parameter.")
}
guard case let .string(to) = args["currency_to"] else {
fatalError()
}
guard case let .number(amount) = args["amount"] else {
fatalError()
}

// 2. Call the wrapped function
let response = try await getExchangeRate(amount: amount, date: date, from: from, to: to)

// 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
return ["content": .string(response)]
}

// Returns the schema of the `getExchangeRate` function
func getExchangeRateSchema() -> Schema {
return Schema(
type: .object,
properties: [
"currency_date": Schema(
type: .string,
description: """
A date that must always be in YYYY-MM-DD format or the value 'latest' if a time period
is not specified
"""
),
"currency_from": Schema(
type: .string,
description: "The currency to convert from in ISO 4217 format"
),
"currency_to": Schema(
type: .string,
description: "The currency to convert to in ISO 4217 format"
),
"amount": Schema(
type: .number,
description: "The amount of currency to convert as a double value"
),
],
required: ["currency_date", "currency_from", "currency_to", "amount"]
)
}

// Returns the sum of a list of integers.
// This is an example function that a developer could provide.
func sumIntegerList(_ integers: [Int]) -> Int {
var sum = 0
for integer in integers {
sum += integer
fatalError("Missing `currency_to` parameter.")
}
return sum
}

// This is a wrapper for the `sumIntegerList` function.
func sumIntegerListWrapper(args: JSONObject) -> JSONObject {
// 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
guard let values = args["values"] else {
fatalError("Expected a `values` parameter.")
}
guard case let .array(argArray) = values else {
fatalError("Expected `values` to be an array.")
// 2. Get the exchange rate
let allRates: [String: [String: Double]] = [
"AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379],
"CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362],
"EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932],
"GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836],
"JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679],
"USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26],
]
guard let fromRates = allRates[from] else {
return ["error": .string("No data for currency \(from).")]
}

var integerArray = [Int]()
for arg in argArray {
guard case let .number(number) = arg else {
fatalError("Expected `values` array elements to be numbers.")
}
guard let integer = Int(exactly: number) else {
fatalError("Expected `values` array numbers to be integers.")
}
integerArray.append(integer)
guard let toRate = fromRates[to] else {
return ["error": .string("No data for currency \(to).")]
}

// 2. Call the wrapped function
let sum = sumIntegerList(integerArray)

// 3. Return the sum as a JSON object (to be returned to the model in a `FunctionResponse`)
return ["sum": .number(Double(sum))]
}

// Returns the schema of the `sumIntegerList` function.
func sumIntegerListSchema() -> Schema {
return Schema(
type: .object,
properties: [
"values": Schema(
type: .array,
description: "The integer values to sum",
items: Schema(
type: .integer,
format: "int64"
)
),
],
required: ["values"]
)
// 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
return ["rates": .number(toRate)]
}
}

15 changes: 13 additions & 2 deletions Sources/GoogleAI/FunctionCalling.swift
Original file line number Diff line number Diff line change
@@ -43,6 +43,17 @@ public class Schema: Encodable {

let required: [String]?

enum CodingKeys: String, CodingKey {
case type
case format
case description
case nullable
case enumValues = "enum"
case items
case properties
case required
}

public init(type: DataType, format: String? = nil, description: String? = nil,
nullable: Bool? = nil,
enumValues: [String]? = nil, items: Schema? = nil,
@@ -75,9 +86,9 @@ public struct FunctionDeclaration {

let description: String

let parameters: Schema
let parameters: Schema?

public init(name: String, description: String, parameters: Schema) {
public init(name: String, description: String, parameters: Schema?) {
self.name = name
self.description = description
self.parameters = parameters