Skip to content

Commit

Permalink
feat: allowed add audit query params to downstream proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
wenerme committed Nov 17, 2023
1 parent 1847020 commit 59a4b7d
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 9 deletions.
107 changes: 107 additions & 0 deletions code/contexts/keys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package contexts

import (
"context"
"fmt"
"net/url"
"reflect"
)

type ChatContext struct {
Tenant string
ChatID string
ChatType string
MessageID string
MessageType string
SenderOpenID string
SenderType string
SenderUnionID string
SenderUserID string
SessionID string
}

func (cc *ChatContext) Encode() string {
v := url.Values{}
v.Set("feishu.tenant", cc.Tenant)
v.Set("feishu.session_id", cc.SessionID)
v.Set("feishu.chat_id", cc.ChatID)
v.Set("feishu.chat_type", cc.ChatType)
v.Set("feishu.message_id", cc.MessageID)
v.Set("feishu.message_type", cc.MessageType)
v.Set("feishu.sender_user_id", cc.SenderUserID)
v.Set("feishu.sender_union_id", cc.SenderUnionID)
v.Set("feishu.sender_open_id", cc.SenderOpenID)
v.Set("feishu.sender_type", cc.SenderType)
for k, vv := range v {
if len(vv) == 0 {
delete(v, k)
}
}
return v.Encode()
}

var ChatContextKey = CreateContextKey[*ChatContext]()

type ContextKey[T any] interface {
Value(ctx context.Context) (T, bool)
Get(ctx context.Context) T
Must(ctx context.Context) T
WithValue(ctx context.Context, val T) context.Context
}

type key[T any] struct {
opts CreateContextKeyOptions[T]
}

func (k key[T]) Value(ctx context.Context) (T, bool) {
o, ok := ctx.Value(k.opts.key).(T)
return o, ok
}

func (k key[T]) Get(ctx context.Context) T {
o, _ := ctx.Value(k.opts.key).(T)
return o
}

func (k key[T]) Must(ctx context.Context) T {
o, ok := ctx.Value(k.opts.key).(T)
if !ok {
panic(fmt.Errorf("%s not found in context", k.String()))
}
return o
}

func (k key[T]) WithValue(ctx context.Context, val T) context.Context {
return context.WithValue(ctx, k.opts.key, val)
}

func (k key[T]) String() string {
name := k.opts.Name
if name != "" {
name = "@" + name
}
return fmt.Sprintf("ContextKey(%s%s)", reflect.TypeOf(new(T)).Elem().String(), name)
}

var _ ContextKey[string] = (*key[string])(nil)

type CreateContextKeyOptions[T any] struct {
Name string
key any
}

func CreateContextKey[T any](opts ...CreateContextKeyOptions[T]) ContextKey[T] {
var opt CreateContextKeyOptions[T]
if len(opts) > 0 {
// reduce
for _, o := range opts {
opt = o
}
}
if opt.Name != "" {
opt.key = opt.Name
} else {
opt.key = reflect.TypeOf(new(T)).Elem()
}
return &key[T]{opts: opt}
}
3 changes: 2 additions & 1 deletion code/handlers/event_common_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package handlers
import (
"context"
"fmt"

"start-feishubot/contexts"
"start-feishubot/initialization"
"start-feishubot/services/openai"
"start-feishubot/utils"
Expand All @@ -21,6 +21,7 @@ type MsgInfo struct {
imageKey string
sessionId *string
mention []*larkim.MentionEvent
Context *contexts.ChatContext
}
type ActionInfo struct {
handler *MessageHandler
Expand Down
9 changes: 7 additions & 2 deletions code/handlers/event_msg_action.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package handlers

import (
"context"
"encoding/json"
"fmt"
"log"
"start-feishubot/contexts"
"strings"
"time"

Expand Down Expand Up @@ -37,11 +39,14 @@ func (*MessageAction) Execute(a *ActionInfo) bool {
Role: "user", Content: a.info.qParsed,
})

ctx := context.Background()
ctx = contexts.ChatContextKey.WithValue(ctx, a.info.Context)

//fmt.Println("msg", msg)
//logger.Debug("msg", msg)
// get ai mode as temperature
aiMode := a.handler.sessionCache.GetAIMode(*a.info.sessionId)
completions, err := a.handler.gpt.Completions(msg, aiMode)
completions, err := a.handler.gpt.Completions(ctx, msg, aiMode)
if err != nil {
replyMsg(*a.ctx, fmt.Sprintf(
"🤖️:消息机器人摆烂了,请稍后再试~\n错误信息: %v", err), a.info.msgId)
Expand Down Expand Up @@ -70,7 +75,7 @@ func (*MessageAction) Execute(a *ActionInfo) bool {
return true
}

//判断msg中的是否包含system role
// 判断msg中的是否包含system role
func hasSystemRole(msg []openai.Messages) bool {
for _, m := range msg {
if m.Role == "system" {
Expand Down
25 changes: 25 additions & 0 deletions code/handlers/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"fmt"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"log"
"start-feishubot/contexts"
"start-feishubot/logger"
"strings"

Expand Down Expand Up @@ -85,6 +87,27 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2
sessionId: sessionId,
mention: mention,
}
{
get := func(s *string) string {
if s == nil {
return ""
}
return *s
}
cc := &contexts.ChatContext{
MessageID: get(msgInfo.msgId),
MessageType: msgInfo.msgType,
ChatID: get(msgInfo.chatId),
ChatType: get(event.Event.Message.ChatType),
SessionID: get(msgInfo.sessionId),
SenderUserID: get(event.Event.Sender.SenderId.UserId),
SenderOpenID: get(event.Event.Sender.SenderId.OpenId),
SenderUnionID: get(event.Event.Sender.SenderId.UnionId),
SenderType: get(event.Event.Sender.SenderType),
Tenant: get(event.Event.Sender.TenantKey),
}
msgInfo.Context = cc
}
data := &ActionInfo{
ctx: &ctx,
handler: &m,
Expand Down Expand Up @@ -127,6 +150,8 @@ func (m MessageHandler) judgeIfMentionMe(mention []*larkim.
if len(mention) != 1 {
return false
}
// for simple debugging, find a way to pass the info to endpoint
log.Printf("mention: name=%v key=%v id.userid=%v id.openid=%v", *mention[0].Name, *mention[0].Key, *mention[0].Id.UserId, *mention[0].Id.OpenId)
return *mention[0].Name == m.config.FeishuBotName
}

Expand Down
2 changes: 2 additions & 0 deletions code/initialization/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type Config struct {
AzureResourceName string
AzureOpenaiToken string
StreamMode bool
AuditQueryParams bool
}

var (
Expand Down Expand Up @@ -89,6 +90,7 @@ func LoadConfig(cfg string) *Config {
AzureResourceName: getViperStringValue("AZURE_RESOURCE_NAME", ""),
AzureOpenaiToken: getViperStringValue("AZURE_OPENAI_TOKEN", ""),
StreamMode: getViperBoolValue("STREAM_MODE", false),
AuditQueryParams: getViperBoolValue("AUDIT_QUERY_PARAMS", false),
}

return config
Expand Down
22 changes: 17 additions & 5 deletions code/services/openai/gpt3.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package openai

import (
"context"
"errors"
"start-feishubot/contexts"
"start-feishubot/initialization"
"start-feishubot/logger"
"strings"

Expand Down Expand Up @@ -68,7 +71,7 @@ func (msg *Messages) CalculateTokenLength() int {
return tokenizer.MustCalToken(text)
}

func (gpt *ChatGPT) Completions(msg []Messages, aiMode AIMode) (resp Messages,
func (gpt *ChatGPT) Completions(ctx context.Context, msg []Messages, aiMode AIMode) (resp Messages,
err error) {
requestBody := ChatGPTRequestBody{
Model: gpt.Model,
Expand All @@ -80,14 +83,23 @@ func (gpt *ChatGPT) Completions(msg []Messages, aiMode AIMode) (resp Messages,
PresencePenalty: 0,
}
gptResponseBody := &ChatGPTResponseBody{}
url := gpt.FullUrl("chat/completions")
fullUrl := gpt.FullUrl("chat/completions")
if initialization.GetConfig().AuditQueryParams {
cc := contexts.ChatContextKey.Must(ctx)
if cc != nil {
encode := cc.Encode()
if encode != "" {
fullUrl = fullUrl + "?" + encode
}
}
}
//fmt.Println(url)
logger.Debug(url)
logger.Debug(fullUrl)
logger.Debug("request body ", requestBody)
if url == "" {
if fullUrl == "" {
return resp, errors.New("无法获取openai请求地址")
}
err = gpt.sendRequestWithBodyType(url, "POST", jsonBody, requestBody, gptResponseBody)
err = gpt.sendRequestWithBodyType(fullUrl, "POST", jsonBody, requestBody, gptResponseBody)
if err == nil && len(gptResponseBody.Choices) > 0 {
resp = gptResponseBody.Choices[0].Message
} else {
Expand Down
2 changes: 1 addition & 1 deletion code/services/openai/gpt3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestCompletions(t *testing.T) {
{Role: "user", Content: "翻译这段话: The assistant messages help store prior responses. They can also be written by a developer to help give examples of desired behavior."},
}
gpt := NewChatGPT(*config)
resp, err := gpt.Completions(msgs, Balance)
resp, err := gpt.Completions(nil, msgs, Balance)
if err != nil {
t.Errorf("TestCompletions failed with error: %v", err)
}
Expand Down

0 comments on commit 59a4b7d

Please sign in to comment.