Skip to content

Commit

Permalink
fix: dispatch tools (#24)
Browse files Browse the repository at this point in the history
* fix: dispatch tools

* chore: skip
  • Loading branch information
stillmatic authored Nov 16, 2023
1 parent 7c4cc75 commit 07a9aa3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 16 deletions.
33 changes: 23 additions & 10 deletions dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ type OpenAIDispatcherConfig struct {
// For any type T and prompt, it will generate and parse the response into T.
type OpenAIDispatcher[T any] struct {
*OpenAIDispatcherConfig
completer ChatCompleter
ti openai.Tool
parser Parser[T]
completer ChatCompleter
ti openai.Tool
systemPrompt string
parser Parser[T]
}

func NewOpenAIDispatcher[T any](name, description string, completer ChatCompleter, cfg *OpenAIDispatcherConfig) *OpenAIDispatcher[T] {
func NewOpenAIDispatcher[T any](name, description, systemPrompt string, completer ChatCompleter, cfg *OpenAIDispatcherConfig) *OpenAIDispatcher[T] {
// note: name must not have spaces - valid json
// we won't check here but the openai client will throw an error
var t T
Expand All @@ -65,12 +66,13 @@ func NewOpenAIDispatcher[T any](name, description string, completer ChatComplete
completer: completer,
ti: ti,
parser: parser,
systemPrompt: systemPrompt,
}
}

func (d *OpenAIDispatcher[T]) Prompt(ctx context.Context, prompt string) (T, error) {
var output T
model := openai.GPT3Dot5Turbo0613
model := openai.GPT3Dot5Turbo1106
temperature := float32(0.0)
maxTokens := 512
if d.OpenAIDispatcherConfig != nil {
Expand All @@ -85,24 +87,35 @@ func (d *OpenAIDispatcher[T]) Prompt(ctx context.Context, prompt string) (T, err
}
}

resp, err := d.completer.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
req := openai.ChatCompletionRequest{
Model: model,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: d.systemPrompt,
},
{
Role: openai.ChatMessageRoleUser,
Content: prompt,
},
},
Tools: []openai.Tool{d.ti},
ToolChoice: d.ti.Function.Name,
Tools: []openai.Tool{d.ti},
ToolChoice: openai.ToolChoice{
Type: "function",
Function: openai.ToolFunction{
Name: d.ti.Function.Name,
}},
Temperature: temperature,
MaxTokens: maxTokens,
})
}

resp, err := d.completer.CreateChatCompletion(ctx, req)
if err != nil {
return output, err
}

output, err = d.parser.Parse(ctx, []byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments))
toolOutput := resp.Choices[0].Message.ToolCalls[0].Function.Arguments
output, err = d.parser.Parse(ctx, []byte(toolOutput))
if err != nil {
return output, err
}
Expand Down
30 changes: 25 additions & 5 deletions dispatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gollum_test

import (
"context"
"os"
"testing"
"text/template"

Expand All @@ -21,6 +22,10 @@ type templateInput struct {
Topic string
}

type wordCountOutput struct {
Count int `json:"count" jsonschema:"required" jsonschema_description:"The number of words in the sentence"`
}

func TestDummyDispatcher(t *testing.T) {
d := gollum.NewDummyDispatcher[testInput]()

Expand All @@ -46,7 +51,8 @@ func TestDummyDispatcher(t *testing.T) {
func TestOpenAIDispatcher(t *testing.T) {
ctrl := gomock.NewController(t)
completer := mock_gollum.NewMockChatCompleter(ctrl)
d := gollum.NewOpenAIDispatcher[testInput]("random_conversation", "Given a topic, return random words", completer, nil)
systemPrompt := "When prompted, use the tool."
d := gollum.NewOpenAIDispatcher[testInput]("random_conversation", "Given a topic, return random words", systemPrompt, completer, nil)

ctx := context.Background()
expected := testInput{
Expand All @@ -58,15 +64,19 @@ func TestOpenAIDispatcher(t *testing.T) {
fi := openai.FunctionDefinition(gollum.StructToJsonSchema("random_conversation", "Given a topic, return random words", testInput{}))
ti := openai.Tool{Type: "function", Function: fi}
expectedRequest := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo0613,
Model: openai.GPT3Dot5Turbo1106,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Role: openai.ChatMessageRoleUser,
Content: "Tell me about dinosaurs",
},
},
Tools: []openai.Tool{ti},
ToolChoice: fi.Name,
Tools: []openai.Tool{ti},
ToolChoice: openai.ToolChoice{
Type: "function",
Function: openai.ToolFunction{
Name: "random_conversation",
}},
MaxTokens: 512,
Temperature: 0.0,
}
Expand Down Expand Up @@ -127,3 +137,13 @@ func TestOpenAIDispatcher(t *testing.T) {
assert.Equal(t, expected, output)
})
}

func TestDispatchIntegration(t *testing.T) {
t.Skip("Skipping integration test")
completer := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
systemPrompt := "When prompted, use the tool on the user's input."
d := gollum.NewOpenAIDispatcher[wordCountOutput]("wordCounter", "count the number of words in a sentence", systemPrompt, completer, nil)
output, err := d.Prompt(context.Background(), "I like dinosaurs")
assert.NoError(t, err)
assert.Equal(t, 3, output.Count)
}
2 changes: 1 addition & 1 deletion functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestEndToEnd(t *testing.T) {
fi := gollum.StructToJsonSchema("weather", "Get the current weather in a given location", getWeatherInput{})

chatRequest := openai.ChatCompletionRequest{
Model: "gpt-3.5-turbo-0613",
Model: openai.GPT3Dot5Turbo1106,
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Expand Down

0 comments on commit 07a9aa3

Please sign in to comment.