Skip to content

Commit

Permalink
Allow extra headers for specific conversations too
Browse files Browse the repository at this point in the history
This commit adds support for setting extra HTTP headers for all
requests issued as part of a specific conversation. These headers
will be sent in addition to any extra headers defined for the
backend itself, if any, and will take precedence over them.

The bedrock implementation does not support extra headers at all
at this point.
  • Loading branch information
ido50 committed Jul 1, 2024
1 parent 1ad9e02 commit 374205d
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 13 deletions.
3 changes: 3 additions & 0 deletions libaiac/bedrock/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,6 @@ func (conv *Conversation) Messages() []types.Message {
}
return msgs
}

// AddHeader is a noop for the bedrock implementation
func (conv *Conversation) AddHeader(_ string, _ string) {}
29 changes: 23 additions & 6 deletions libaiac/ollama/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import (
// Conversation is a struct used to converse with an Ollama chat model. It
// maintains all messages sent/received in order to maintain context.
type Conversation struct {
backend *Ollama
model string
messages []types.Message
backend *Ollama
model string
messages []types.Message
extraHeaders map[string]string
}

type chatResponse struct {
Expand Down Expand Up @@ -55,7 +56,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
Content: prompt,
})

err = conv.backend.NewRequest("POST", "/chat").
req := conv.backend.NewRequest("POST", "/chat").
JSONBody(map[string]interface{}{
"model": conv.model,
"messages": conv.messages,
Expand All @@ -64,8 +65,13 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
},
"stream": false,
}).
Into(&answer).
RunContext(ctx)
Into(&answer)

for key, val := range conv.extraHeaders {
req.Header(key, val)
}

err = req.RunContext(ctx)
if err != nil {
return res, fmt.Errorf("failed sending prompt: %w", err)
}
Expand All @@ -92,3 +98,14 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
func (conv *Conversation) Messages() []types.Message {
return conv.messages
}

// AddHeader adds an extra HTTP header that will be added to every HTTP
// request issued as part of this conversation. Any headers added will be in
// addition to any extra headers defined for the backend itself, and will
// take precedence over them.
func (conv *Conversation) AddHeader(key, val string) {
if conv.extraHeaders == nil {
conv.extraHeaders = make(map[string]string)
}
conv.extraHeaders[key] = val
}
29 changes: 23 additions & 6 deletions libaiac/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ import (
// maintains all messages sent/received in order to maintain context just like
// using ChatGPT.
type Conversation struct {
backend *OpenAI
model string
messages []types.Message
backend *OpenAI
model string
messages []types.Message
extraHeaders map[string]string
}

type chatResponse struct {
Expand Down Expand Up @@ -67,15 +68,20 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
apiVersion = fmt.Sprintf("?api-version=%s", conv.backend.apiVersion)
}

err = conv.backend.
req := conv.backend.
NewRequest("POST", fmt.Sprintf("/chat/completions%s", apiVersion)).
JSONBody(map[string]interface{}{
"model": conv.model,
"messages": conv.messages,
"temperature": 0.2,
}).
Into(&answer).
RunContext(ctx)
Into(&answer)

for key, val := range conv.extraHeaders {
req.Header(key, val)
}

err = req.RunContext(ctx)
if err != nil {
return res, fmt.Errorf("failed sending prompt: %w", err)
}
Expand Down Expand Up @@ -104,3 +110,14 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
func (conv *Conversation) Messages() []types.Message {
return conv.messages
}

// AddHeader adds an extra HTTP header that will be added to every HTTP
// request issued as part of this conversation. Any headers added will be in
// addition to any extra headers defined for the backend itself, and will
// take precedence over them.
func (conv *Conversation) AddHeader(key, val string) {
if conv.extraHeaders == nil {
conv.extraHeaders = make(map[string]string)
}
conv.extraHeaders[key] = val
}
9 changes: 8 additions & 1 deletion libaiac/types/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ type Conversation interface {
Send(context.Context, string) (Response, error)

// Messages returns all the messages that have been exchanged between the
// user and the assistant up to this point
// user and the assistant up to this point.
Messages() []Message

// AddHeader adds an extra HTTP header that will be added to every HTTP
// request issued as part of this conversation. Any headers added will be in
// addition to any extra headers defined for the backend itself, and will
// take precedence over them. Not all providers may support this
// (specifically, bedrock doesn't).
AddHeader(string, string)
}

0 comments on commit 374205d

Please sign in to comment.