Skip to content

Commit

Permalink
Merge pull request #823 from trheyi/main
Browse files Browse the repository at this point in the history
Refactor Neo API assistant message handling and content structure
  • Loading branch information
trheyi authored Jan 19, 2025
2 parents 1d9d9b2 + 830bbf7 commit 8ee8429
Show file tree
Hide file tree
Showing 8 changed files with 313 additions and 157 deletions.
61 changes: 26 additions & 35 deletions neo/assistant/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,23 @@ func (ast *Assistant) Execute(c *gin.Context, ctx chatctx.Context, input string,
// Run init hook
res, err := ast.HookInit(c, ctx, messages, options)
if err != nil {
chatMessage.New().
Assistant(ast.ID, ast.Name, ast.Avatar).
Error(err).
Done().
Write(c.Writer)
return err
}

// Switch to the new assistant if necessary
if res != nil && res.AssistantID != ctx.AssistantID {
newAst, err := Get(res.AssistantID)
if err != nil {
chatMessage.New().
Assistant(ast.ID, ast.Name, ast.Avatar).
Error(err).
Done().
Write(c.Writer)
return err
}
*ast = *newAst
Expand Down Expand Up @@ -163,16 +173,16 @@ func (next *NextAction) Execute(c *gin.Context, ctx chatctx.Context) error {
func (ast *Assistant) handleChatStream(c *gin.Context, ctx chatctx.Context, messages []chatMessage.Message, options map[string]interface{}) error {
clientBreak := make(chan bool, 1)
done := make(chan bool, 1)
content := chatMessage.NewContent("text")
contents := chatMessage.NewContents()

// Chat with AI in background
go func() {
err := ast.streamChat(c, ctx, messages, options, clientBreak, done, content)
err := ast.streamChat(c, ctx, messages, options, clientBreak, done, contents)
if err != nil {
chatMessage.New().Error(err).Done().Write(c.Writer)
}

ast.saveChatHistory(ctx, messages, content)
ast.saveChatHistory(ctx, messages, contents)
done <- true
}()

Expand All @@ -194,7 +204,7 @@ func (ast *Assistant) streamChat(
options map[string]interface{},
clientBreak chan bool,
done chan bool,
content *chatMessage.Content) error {
contents *chatMessage.Contents) error {

return ast.Chat(c.Request.Context(), messages, options, func(data []byte) int {
select {
Expand All @@ -210,7 +220,7 @@ func (ast *Assistant) streamChat(
// Handle error
if msg.Type == "error" {
value := msg.String()
res, hookErr := ast.HookFail(c, ctx, messages, content.String(), fmt.Errorf("%s", value))
res, hookErr := ast.HookFail(c, ctx, messages, contents.JSON(), fmt.Errorf("%s", value))
if hookErr == nil && res != nil && (res.Output != "" || res.Error != "") {
value = res.Output
if res.Error != "" {
Expand All @@ -221,30 +231,13 @@ func (ast *Assistant) streamChat(
return 0 // break
}

// Handle tool call
if msg.Type == "tool_calls" {
content.SetType("function") // Set type to function
// Set id
if id, ok := msg.Props["id"].(string); ok && id != "" {
content.SetID(id)
}

// Set name
if name, ok := msg.Props["name"].(string); ok && name != "" {
content.SetName(name)
}
}

// Append content and send message
msg.AppendTo(contents)
value := msg.String()
content.Append(value)
if value != "" {
// Handle stream
res, err := ast.HookStream(c, ctx, messages, content.String(), content.Type == "function")
res, err := ast.HookStream(c, ctx, messages, contents.Data)
if err == nil && res != nil {
if res.Output != "" {
value = res.Output
}

if res.Next != nil {
err = res.Next.Execute(c, ctx)
Expand Down Expand Up @@ -274,18 +267,16 @@ func (ast *Assistant) streamChat(

// Complete the stream
if msg.IsDone {
if value == "" {
msg.Write(c.Writer)
}
// if value == "" {
// msg.Write(c.Writer)
// }

// Call HookDone
content.SetStatus(chatMessage.ContentStatusDone)
res, hookErr := ast.HookDone(c, ctx, messages, content.String(), content.Type == "function")
res, hookErr := ast.HookDone(c, ctx, messages, contents.Data)
if hookErr == nil && res != nil {
if res.Output != "" {
if res.Output != nil {
chatMessage.New().
Map(map[string]interface{}{
"text": res.Output,
"text": res.Input,
"done": true,
}).
Write(c.Writer)
Expand Down Expand Up @@ -322,8 +313,8 @@ func (ast *Assistant) streamChat(
}

// saveChatHistory saves the chat history if storage is available
func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []chatMessage.Message, content *chatMessage.Content) {
if len(content.Bytes) > 0 && ctx.Sid != "" && len(messages) > 0 {
func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []chatMessage.Message, contents *chatMessage.Contents) {
if len(contents.Data) > 0 && ctx.Sid != "" && len(messages) > 0 {
userMessage := messages[len(messages)-1]
data := []map[string]interface{}{
{
Expand All @@ -333,7 +324,7 @@ func (ast *Assistant) saveChatHistory(ctx chatctx.Context, messages []chatMessag
},
{
"role": "assistant",
"content": content.String(),
"content": contents.JSON(),
"name": ctx.Sid,
"assistant_id": ast.ID,
"assistant_name": ast.Name,
Expand Down
72 changes: 63 additions & 9 deletions neo/assistant/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/gin-gonic/gin"
jsoniter "github.com/json-iterator/go"
chatctx "github.com/yaoapp/yao/neo/context"
"github.com/yaoapp/yao/neo/message"
)
Expand Down Expand Up @@ -34,6 +35,17 @@ func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []
response.ChatID = res
}

// input
if input, has := v["input"]; has {
raw, _ := jsoniter.MarshalToString(input)
vv := []message.Message{}
err := jsoniter.UnmarshalFromString(raw, &vv)
if err != nil {
return nil, err
}
response.Input = vv
}

if res, ok := v["next"].(map[string]interface{}); ok {
response.Next = &NextAction{}
if name, ok := res["action"].(string); ok {
Expand All @@ -57,13 +69,13 @@ func (ast *Assistant) HookInit(c *gin.Context, context chatctx.Context, input []
}

// HookStream Handle streaming response from LLM
func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output string, toolcall bool) (*ResHookStream, error) {
func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input []message.Message, output []message.Data) (*ResHookStream, error) {

// Create timeout context
ctx, cancel := ast.createTimeoutContext(c)
defer cancel()

v, err := ast.call(ctx, "Stream", context, input, output, toolcall, c.Writer)
v, err := ast.call(ctx, "Stream", context, input, output, c.Writer)
if err != nil {
if err.Error() == HookErrorMethodNotFound {
return nil, nil
Expand All @@ -75,8 +87,24 @@ func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input
switch v := v.(type) {
case map[string]interface{}:
if res, ok := v["output"].(string); ok {
response.Output = res
vv := []message.Data{}
err := jsoniter.UnmarshalFromString(res, &vv)
if err != nil {
return nil, err
}
response.Output = vv
}

if res, ok := v["output"].([]interface{}); ok {
vv := []message.Data{}
raw, _ := jsoniter.MarshalToString(res)
err := jsoniter.UnmarshalFromString(raw, &vv)
if err != nil {
return nil, err
}
response.Output = vv
}

if res, ok := v["next"].(map[string]interface{}); ok {
response.Next = &NextAction{}
if name, ok := res["action"].(string); ok {
Expand All @@ -93,19 +121,24 @@ func (ast *Assistant) HookStream(c *gin.Context, context chatctx.Context, input
}

case string:
response.Output = v
vv := []message.Data{}
err := jsoniter.UnmarshalFromString(v, &vv)
if err != nil {
return nil, err
}
response.Output = vv
}

return response, nil
}

// HookDone Handle completion of assistant response
func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output string, toolcall bool) (*ResHookDone, error) {
func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []message.Message, output []message.Data) (*ResHookDone, error) {
// Create timeout context
ctx, cancel := ast.createTimeoutContext(c)
defer cancel()

v, err := ast.call(ctx, "Done", context, input, output, toolcall, c.Writer)
v, err := ast.call(ctx, "Done", context, input, output, c.Writer)
if err != nil {
if err.Error() == HookErrorMethodNotFound {
return nil, nil
Expand All @@ -121,8 +154,24 @@ func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []
switch v := v.(type) {
case map[string]interface{}:
if res, ok := v["output"].(string); ok {
response.Output = res
vv := []message.Data{}
err := jsoniter.UnmarshalFromString(res, &vv)
if err != nil {
return nil, err
}
response.Output = vv
}

if res, ok := v["output"].([]interface{}); ok {
vv := []message.Data{}
raw, _ := jsoniter.MarshalToString(res)
err := jsoniter.UnmarshalFromString(raw, &vv)
if err != nil {
return nil, err
}
response.Output = vv
}

if res, ok := v["next"].(map[string]interface{}); ok {
response.Next = &NextAction{}
if name, ok := res["action"].(string); ok {
Expand All @@ -133,7 +182,12 @@ func (ast *Assistant) HookDone(c *gin.Context, context chatctx.Context, input []
}
}
case string:
response.Output = v
vv := []message.Data{}
err := jsoniter.UnmarshalFromString(v, &vv)
if err != nil {
return nil, err
}
response.Output = vv
}

return response, nil
Expand Down Expand Up @@ -185,7 +239,7 @@ func (ast *Assistant) HookFail(c *gin.Context, context chatctx.Context, input []

// createTimeoutContext creates a timeout context with 5 seconds timeout
func (ast *Assistant) createTimeoutContext(c *gin.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
return ctx, cancel
}

Expand Down
4 changes: 4 additions & 0 deletions neo/assistant/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ var defaultConnector string = "" // default connector

// LoadBuiltIn load the built-in assistants
func LoadBuiltIn() error {

// Clear the cache
loaded.Clear()

root := `/assistants`
app, err := fs.Get("app")
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions neo/assistant/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ type ResHookInit struct {

// ResHookStream the response of the stream hook
type ResHookStream struct {
Silent bool `json:"silent,omitempty"` // Whether to suppress the output
Next *NextAction `json:"next,omitempty"` // The next action
Output string `json:"output,omitempty"` // The output
Silent bool `json:"silent,omitempty"` // Whether to suppress the output
Next *NextAction `json:"next,omitempty"` // The next action
Output []message.Data `json:"output,omitempty"` // The output
}

// ResHookDone the response of the done hook
type ResHookDone struct {
Next *NextAction `json:"next,omitempty"`
Input []message.Message `json:"input,omitempty"`
Output string `json:"output,omitempty"`
Output []message.Data `json:"output,omitempty"`
}

// ResHookFail the response of the fail hook
Expand Down
Loading

0 comments on commit 8ee8429

Please sign in to comment.