feat: update to use tools api
stillmatic committed Nov 16, 2023
1 parent e7f7664 commit 7c4cc75
Showing 10 changed files with 104 additions and 67 deletions.
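At a high level, the change swaps go-openai's deprecated `Functions`/`FunctionCall` request fields for `Tools`/`ToolChoice`, and reads model output from `Message.ToolCalls` instead of `Message.FunctionCall`. A rough before/after sketch (not code from this commit; only the field names come from the diff below, the surrounding request is illustrative):

```go
package main

import "github.com/sashabaranov/go-openai"

// beforeAndAfter is illustrative only: msgs and fd stand in for a real
// message list and function schema; neither request is sent anywhere.
func beforeAndAfter(msgs []openai.ChatCompletionMessage, fd openai.FunctionDefinition) (before, after openai.ChatCompletionRequest) {
	// Before this commit: the functions API.
	// Output was read from resp.Choices[0].Message.FunctionCall.Arguments.
	before = openai.ChatCompletionRequest{
		Model:        "gpt-3.5-turbo-0613",
		Messages:     msgs,
		Functions:    []openai.FunctionDefinition{fd},
		FunctionCall: "auto",
	}

	// After this commit: the tools API.
	// Output is read from resp.Choices[0].Message.ToolCalls[0].Function.Arguments.
	after = openai.ChatCompletionRequest{
		Model:      "gpt-3.5-turbo-0613",
		Messages:   msgs,
		Tools:      []openai.Tool{{Type: "function", Function: fd}},
		ToolChoice: "auto",
	}
	return before, after
}
```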
35 changes: 18 additions & 17 deletions README.md
@@ -78,28 +78,24 @@ type getWeatherInput struct {

fi := gollum.StructToJsonSchema("weather", "Get the current weather in a given location", getWeatherInput{})

chatRequest := chatCompletionRequest{
ChatCompletionRequest: openai.ChatCompletionRequest{
Model: "gpt-3.5-turbo-0613",
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: "Whats the temperature in Boston?",
},
chatRequest := openai.ChatCompletionRequest{
Model: "gpt-3.5-turbo-0613",
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: "Whats the temperature in Boston?",
},
MaxTokens: 512,
Temperature: 0.0,
},
Functions: []gollum.FunctionInput{
fi,
},
FunctionCall: "auto",
MaxTokens: 256,
Temperature: 0.0,
Tools: []openai.Tool{{Type: "function", Function: openai.FunctionDefinition(fi)}},
ToolChoice: "weather",
}

ctx := context.Background()
resp, err := api.SendRequest(ctx, chatRequest)
parser := gollum.NewJSONParser[getWeatherInput](false)
input, err := parser.Parse(ctx, resp.Choices[0].Message.FunctionCall.Arguments)
input, err := parser.Parse(ctx, resp.Choices[0].Message.ToolCalls[0].Function.Arguments)
```

This example steps through all of that, end to end. Some of it is 'sort of' pseudo-code, as the OpenAI clients I use haven't yet implemented support for functions, but it should also show that only minimal modifications to upstream libraries are necessary.
@@ -131,10 +127,15 @@ chatRequest := chatCompletionRequest{
MaxTokens: 256,
Temperature: 0.0,
},
Functions: []gollum.FunctionInput{fi},
Tools: []openai.Tool{
{
Type: "function",
Function: fi,
}
}
}
parser := gollum.NewJSONParser[openai.ChatCompletionRequest](false)
input, err := parser.Parse(ctx, resp.Choices[0].Message.FunctionCall.Arguments)
input, err := parser.Parse(ctx, resp.Choices[0].Message.ToolCalls[0].Function.Arguments)
```

On the first try, this yielded the following result:
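For reference, here is a minimal, self-contained sketch of the tools-based flow the README example describes; it is not code from this commit. The `getWeatherInput` struct, client setup, import path, and error handling are assumptions, while the `StructToJsonSchema`, `Tools`, `ToolChoice`, and `ToolCalls` usage mirrors the diffs in this commit.

```go
package main

import (
	"context"
	"fmt"

	"github.com/sashabaranov/go-openai"
	"github.com/stillmatic/gollum" // import path assumed
)

// getWeatherInput is assumed here; the README defines a similar struct.
type getWeatherInput struct {
	Location string `json:"location"`
}

func main() {
	// Build a JSON schema for the tool from the input struct.
	fi := gollum.StructToJsonSchema("weather", "Get the current weather in a given location", getWeatherInput{})

	chatRequest := openai.ChatCompletionRequest{
		Model: "gpt-3.5-turbo-0613",
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "What's the temperature in Boston?"},
		},
		MaxTokens:   256,
		Temperature: 0.0,
		Tools:       []openai.Tool{{Type: "function", Function: openai.FunctionDefinition(fi)}},
		ToolChoice:  "weather", // mirrors the README diff above
	}

	client := openai.NewClient("YOUR_API_KEY") // assumed client setup, not from the diff
	ctx := context.Background()
	resp, err := client.CreateChatCompletion(ctx, chatRequest)
	if err != nil {
		panic(err)
	}

	// Parse the tool call arguments back into the typed struct.
	parser := gollum.NewJSONParserGeneric[getWeatherInput](false)
	input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", input)
}
```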
15 changes: 7 additions & 8 deletions agents/calcagent.go
@@ -43,10 +43,6 @@ func (c *CalcAgent) Description() string {
return "convert natural language and evaluate mathematical expressions"
}

type functionCall struct {
Name string `json:"name"`
}

func (c *CalcAgent) Run(ctx context.Context, input interface{}) (interface{}, error) {
cinput, ok := input.(CalcAgentInput)
if !ok {
@@ -65,15 +61,18 @@ func (c *CalcAgent) Run(ctx context.Context, input interface{}) (interface{}, er
Content: cinput.Content,
},
},
MaxTokens: 128,
Functions: []openai.FunctionDefinition{c.functionInput},
FunctionCall: functionCall{Name: "calculator"},
MaxTokens: 128,
Tools: []openai.Tool{{
Type: "function",
Function: c.functionInput,
}},
ToolChoice: "calculator",
})
if err != nil {
return "", errors.Wrap(err, "couldn't call the LLM")
}
// parse response
parsed, err := c.parser.Parse(ctx, []byte(resp.Choices[0].Message.FunctionCall.Arguments))
parsed, err := c.parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
if err != nil {
return "", errors.Wrap(err, "couldn't parse response")
}
8 changes: 5 additions & 3 deletions agents/calcagent_test.go
@@ -32,9 +32,11 @@ func TestCalcAgentMocked(t *testing.T) {
Choices: []openai.ChatCompletionChoice{
{
Message: openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: "",
FunctionCall: &openai.FunctionCall{Name: "calc", Arguments: string(expectedBytes)},
Role: openai.ChatMessageRoleAssistant,
Content: "",
ToolCalls: []openai.ToolCall{
{Function: openai.FunctionCall{Name: "calc", Arguments: string(expectedBytes)}},
},
},
},
},
13 changes: 6 additions & 7 deletions dispatch.go
@@ -49,7 +49,7 @@ type OpenAIDispatcherConfig struct {
type OpenAIDispatcher[T any] struct {
*OpenAIDispatcherConfig
completer ChatCompleter
fi openai.FunctionDefinition
ti openai.Tool
parser Parser[T]
}

@@ -58,11 +58,12 @@ func NewOpenAIDispatcher[T any](name, description string, completer ChatComplete
// we won't check here but the openai client will throw an error
var t T
fi := StructToJsonSchema(name, description, t)
ti := FunctionInputToTool(fi)
parser := NewJSONParserGeneric[T](true)
return &OpenAIDispatcher[T]{
OpenAIDispatcherConfig: cfg,
completer: completer,
fi: openai.FunctionDefinition(fi),
ti: ti,
parser: parser,
}
}
@@ -92,18 +93,16 @@ func (d *OpenAIDispatcher[T]) Prompt(ctx context.Context, prompt string) (T, err
Content: prompt,
},
},
Functions: []openai.FunctionDefinition{d.fi},
FunctionCall: struct {
Name string `json:"name"`
}{Name: d.fi.Name},
Tools: []openai.Tool{d.ti},
ToolChoice: d.ti.Function.Name,
Temperature: temperature,
MaxTokens: maxTokens,
})
if err != nil {
return output, err
}

output, err = d.parser.Parse(ctx, []byte(resp.Choices[0].Message.FunctionCall.Arguments))
output, err = d.parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
if err != nil {
return output, err
}
31 changes: 19 additions & 12 deletions dispatch_test.go
@@ -56,6 +56,7 @@ func TestOpenAIDispatcher(t *testing.T) {
inpStr := `{"topic": "dinosaurs", "random_words": ["dinosaur", "fossil", "extinct"]}`

fi := openai.FunctionDefinition(gollum.StructToJsonSchema("random_conversation", "Given a topic, return random words", testInput{}))
ti := openai.Tool{Type: "function", Function: fi}
expectedRequest := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo0613,
Messages: []openai.ChatCompletionMessage{
@@ -64,10 +65,8 @@
Content: "Tell me about dinosaurs",
},
},
Functions: []openai.FunctionDefinition{fi},
FunctionCall: struct {
Name string `json:"name"`
}{Name: fi.Name},
Tools: []openai.Tool{ti},
ToolChoice: fi.Name,
MaxTokens: 512,
Temperature: 0.0,
}
@@ -80,10 +79,14 @@
Message: openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem,
Content: "Hello there!",
FunctionCall: &openai.FunctionCall{
Name: "random_conversation",
Arguments: inpStr,
},
ToolCalls: []openai.ToolCall{
{
Type: "function",
Function: openai.FunctionCall{
Name: "random_conversation",
Arguments: inpStr,
},
}},
},
},
},
@@ -101,10 +104,14 @@
Message: openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem,
Content: "Hello there!",
FunctionCall: &openai.FunctionCall{
Name: "random_conversation",
Arguments: inpStr,
},
ToolCalls: []openai.ToolCall{
{
Type: "function",
Function: openai.FunctionCall{
Name: "random_conversation",
Arguments: inpStr,
},
}},
},
},
},
14 changes: 14 additions & 0 deletions functions.go
@@ -4,6 +4,7 @@ import (
"reflect"

"github.com/invopop/jsonschema"
"github.com/sashabaranov/go-openai"
)

type FunctionInput struct {
@@ -12,6 +13,19 @@ type FunctionInput struct {
Parameters any `json:"parameters"`
}

type OAITool struct {
// Type is always "function" for now.
Type string `json:"type"`
Function FunctionInput `json:"function"`
}

func FunctionInputToTool(fi FunctionInput) openai.Tool {
return openai.Tool{
Type: "function",
Function: openai.FunctionDefinition(fi),
}
}

func StructToJsonSchema(functionName string, functionDescription string, inputStruct interface{}) FunctionInput {
t := reflect.TypeOf(inputStruct)
schema := jsonschema.ReflectFromType(reflect.Type(t))
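A short usage sketch for the new `FunctionInputToTool` helper follows; the `searchInput` struct and the "search" tool are made up for illustration, while the helper and `StructToJsonSchema` signatures come from the diff above.

```go
package main

import (
	"github.com/sashabaranov/go-openai"
	"github.com/stillmatic/gollum" // import path assumed
)

// searchInput is illustrative; any struct with json tags works.
type searchInput struct {
	Query string `json:"query"`
}

// buildSearchTool turns the struct into a JSON schema and wraps it as an
// openai.Tool, ready to drop into ChatCompletionRequest.Tools.
func buildSearchTool() openai.Tool {
	fi := gollum.StructToJsonSchema("search", "Search the web for a query", searchInput{})
	return gollum.FunctionInputToTool(fi)
}
```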
39 changes: 23 additions & 16 deletions functions_test.go
@@ -81,7 +81,8 @@ func TestEndToEnd(t *testing.T) {
},
MaxTokens: 256,
Temperature: 0.0,
Functions: []openai.FunctionDefinition{openai.FunctionDefinition(fi)},
Tools: []openai.Tool{{Type: "function", Function: openai.FunctionDefinition(fi)}},
ToolChoice: "weather",
}

ctx := context.Background()
@@ -91,15 +92,15 @@
assert.Equal(t, resp.Model, "gpt-3.5-turbo-0613")
assert.NotEmpty(t, resp.Choices)
assert.Empty(t, resp.Choices[0].Message.Content)
assert.NotNil(t, resp.Choices[0].Message.FunctionCall)
assert.Equal(t, resp.Choices[0].Message.FunctionCall.Name, "weather")
assert.NotNil(t, resp.Choices[0].Message.ToolCalls)
assert.Equal(t, resp.Choices[0].Message.ToolCalls[0].Function.Name, "weather")

// this is somewhat flaky - about 20% of the time it returns 'Boston'
expectedArg := []byte(`{"location": "Boston, MA"}`)
parser := gollum.NewJSONParserGeneric[getWeatherInput](false)
expectedStruct, err := parser.Parse(ctx, expectedArg)
assert.NoError(t, err)
input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.FunctionCall.Arguments))
input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
assert.NoError(t, err)
assert.Equal(t, expectedStruct, input)
})
@@ -116,8 +117,8 @@
},
MaxTokens: 256,
Temperature: 0.0,
Functions: []openai.FunctionDefinition{
openai.FunctionDefinition(fi),
Tools: []openai.Tool{
{Type: "function", Function: openai.FunctionDefinition(fi)},
},
}
ctx := context.Background()
@@ -127,15 +128,15 @@
assert.Equal(t, resp.Model, "gpt-3.5-turbo-0613")
assert.NotEmpty(t, resp.Choices)
assert.Empty(t, resp.Choices[0].Message.Content)
assert.NotNil(t, resp.Choices[0].Message.FunctionCall)
assert.Equal(t, resp.Choices[0].Message.FunctionCall.Name, "split_word")
assert.NotNil(t, resp.Choices[0].Message.ToolCalls)
assert.Equal(t, resp.Choices[0].Message.ToolCalls[0].Function.Name, "split_word")

expectedStruct := counter{
Count: 7,
Words: []string{"What", "is", "the", "weather", "like", "in", "Boston?"},
}
parser := gollum.NewJSONParserGeneric[counter](false)
input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.FunctionCall.Arguments))
input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
assert.NoError(t, err)
assert.Equal(t, expectedStruct, input)
})
@@ -157,7 +158,9 @@
},
MaxTokens: 256,
Temperature: 0.0,
Functions: []openai.FunctionDefinition{openai.FunctionDefinition(fi)},
Tools: []openai.Tool{
{Type: "function", Function: openai.FunctionDefinition(fi)},
},
}

ctx := context.Background()
@@ -166,11 +169,11 @@
assert.Equal(t, resp.Model, "gpt-3.5-turbo-0613")
assert.NotEmpty(t, resp.Choices)
assert.Empty(t, resp.Choices[0].Message.Content)
assert.NotNil(t, resp.Choices[0].Message.FunctionCall)
assert.Equal(t, resp.Choices[0].Message.FunctionCall.Name, "ChatCompletion")
assert.NotNil(t, resp.Choices[0].Message.ToolCalls)
assert.Equal(t, resp.Choices[0].Message.ToolCalls[0].Function.Name, "ChatCompletion")

parser := gollum.NewJSONParserGeneric[openai.ChatCompletionRequest](false)
input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.FunctionCall.Arguments))
input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
assert.NoError(t, err)
assert.NotEmpty(t, input)

@@ -205,7 +208,9 @@
},
MaxTokens: 256,
Temperature: 0.0,
Functions: []openai.FunctionDefinition{openai.FunctionDefinition(fi)},
Tools: []openai.Tool{
{Type: "function", Function: openai.FunctionDefinition(fi)},
},
}
ctx := context.Background()
resp, err := api.SendRequest(ctx, chatRequest)
@@ -214,7 +219,7 @@
assert.Equal(t, 0, 1)

parser := gollum.NewJSONParserGeneric[blobNode](false)
input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.FunctionCall.Arguments))
input, err := parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
assert.NoError(t, err)
assert.NotEmpty(t, input)
assert.Equal(t, input, blobNode{
@@ -282,7 +287,9 @@ Output: What is the population of Jason's home country?
},
MaxTokens: 256,
Temperature: 0.0,
Functions: []openai.FunctionDefinition{openai.FunctionDefinition(fi)},
Tools: []openai.Tool{
{Type: "function", Function: openai.FunctionDefinition(fi)},
},
}
ctx := context.Background()
resp, err := api.SendRequest(ctx, chatRequest)
2 changes: 1 addition & 1 deletion go.mod
@@ -11,7 +11,7 @@ require (
github.com/klauspost/compress v1.17.2
github.com/pkg/errors v0.9.1
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1
github.com/sashabaranov/go-openai v1.17.2
github.com/sashabaranov/go-openai v1.17.7
github.com/stretchr/testify v1.8.4
github.com/viterin/vek v0.4.2
go.uber.org/mock v0.3.0
5 changes: 5 additions & 0 deletions go.sum
@@ -89,6 +89,7 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF
github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4=
github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
@@ -102,6 +103,8 @@ github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 h1:lZUw3E0/J3roVtGQ+SCrUrg3ON6Ng
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1/go.mod h1:uToXkOrWAZ6/Oc07xWQrPOhJotwFIyu2bBVN41fcDUY=
github.com/sashabaranov/go-openai v1.17.2 h1:Uj1Msqh43S9XhjUXYyOqOHMiRQtgQXCo5O0FeWZz7tU=
github.com/sashabaranov/go-openai v1.17.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM=
github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -169,6 +172,7 @@ google.golang.org/api v0.150.0/go.mod h1:ccy+MJ6nrYFgE3WgRx/AMXOxOmU8Q4hSa+jjibz
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
@@ -198,6 +202,7 @@ google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

