Skip to content

Commit

Permalink
feat: add tool_call_once and two_model_chat to graph examples
Browse files Browse the repository at this point in the history
Change-Id: I4bc86918954cae88cc8f5b4fb4deea0d0895bfc7
  • Loading branch information
shentongmartin committed Jan 24, 2025
1 parent ff3a7a6 commit f3d1920
Show file tree
Hide file tree
Showing 2 changed files with 288 additions and 0 deletions.
159 changes: 159 additions & 0 deletions compose/graph/tool_call_once/tool_call_once.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package main

import (
"context"
"errors"
"os"

"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/components/tool/utils"
. "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"

"github.com/cloudwego/eino-examples/internal/gptr"
"github.com/cloudwego/eino-examples/internal/logs"
)

// main builds and runs a one-shot tool-calling graph:
//
//	START -> template -> model -> (branch) -> tools -> converter -> END
//
// The model is bound to a single fake "user_info" tool. If the model emits
// tool calls, the branch routes to the tools node, whose multi-message output
// is reduced to a single message before reaching END; otherwise the model's
// answer goes straight to END.
func main() {
	openAIBaseURL := os.Getenv("OPENAI_BASE_URL")
	openAIAPIKey := os.Getenv("OPENAI_API_KEY")
	modelName := os.Getenv("MODEL_NAME")

	ctx := context.Background()

	systemTpl := `你是一名房产经纪人,结合用户的薪酬和工作,使用 user_info API,为其提供相关的房产信息。邮箱是必须的`
	chatTpl := prompt.FromMessages(schema.FString,
		schema.SystemMessage(systemTpl),
		schema.MessagesPlaceholder("message_histories", true),
		schema.UserMessage("{query}"),
	)

	modelConf := &openai.ChatModelConfig{
		BaseURL:     openAIBaseURL,
		APIKey:      openAIAPIKey,
		ByAzure:     true,
		Model:       modelName,
		Temperature: gptr.Of(float32(0.7)),
		APIVersion:  "2024-06-01",
	}

	chatModel, err := openai.NewChatModel(ctx, modelConf)
	if err != nil {
		logs.Errorf("NewChatModel failed, err=%v", err)
		return
	}

	// A stub tool: echoes name/email back with hard-coded company data.
	userInfoTool := utils.NewTool(
		&schema.ToolInfo{
			Name: "user_info",
			Desc: "根据用户的姓名和邮箱,查询用户的公司、职位、薪酬信息",
			ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
				"name": {
					Type: "string",
					Desc: "用户的姓名",
				},
				"email": {
					Type: "string",
					Desc: "用户的邮箱",
				},
			}),
		},
		func(ctx context.Context, input *userInfoRequest) (output *userInfoResponse, err error) {
			return &userInfoResponse{
				Name:     input.Name,
				Email:    input.Email,
				Company:  "Bytedance",
				Position: "CEO",
				Salary:   "9999",
			}, nil
		})

	info, err := userInfoTool.Info(ctx)
	if err != nil {
		logs.Errorf("Get ToolInfo failed, err=%v", err)
		return
	}

	// Bind the tool schema to the model so it can emit matching tool calls.
	err = chatModel.BindTools([]*schema.ToolInfo{info})
	if err != nil {
		logs.Errorf("BindTools failed, err=%v", err)
		return
	}

	toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{
		Tools: []tool.BaseTool{userInfoTool},
	})
	if err != nil {
		logs.Errorf("NewToolNode failed, err=%v", err)
		return
	}

	// The tools node outputs []*schema.Message; reduce it to the first
	// message so the graph's output type (*schema.Message) is satisfied.
	takeOne := InvokableLambda(func(ctx context.Context, input []*schema.Message) (*schema.Message, error) {
		if len(input) == 0 {
			return nil, errors.New("input is empty")
		}
		return input[0], nil
	})

	const (
		nodeModel     = "node_model"
		nodeTools     = "node_tools"
		nodeTemplate  = "node_template"
		nodeConverter = "node_converter"
	)

	// Route on the first streamed chunk: tool calls go to the tools node,
	// a plain answer terminates the graph.
	branch := NewStreamGraphBranch(func(ctx context.Context, input *schema.StreamReader[*schema.Message]) (string, error) {
		defer input.Close()
		msg, err := input.Recv()
		if err != nil {
			return "", err
		}

		if len(msg.ToolCalls) > 0 {
			return "tools", nil
		}

		return END, nil
	}, map[string]bool{END: true, nodeTools: true})

	graph := NewGraph[map[string]any, *schema.Message]()

	_ = graph.AddChatTemplateNode(nodeTemplate, chatTpl)
	_ = graph.AddChatModelNode(nodeModel, chatModel)
	_ = graph.AddToolsNode(nodeTools, toolsNode)
	_ = graph.AddLambdaNode(nodeConverter, takeOne)

	_ = graph.AddEdge(START, nodeTemplate)
	_ = graph.AddEdge(nodeTemplate, nodeModel)
	_ = graph.AddBranch(nodeModel, branch)
	_ = graph.AddEdge(nodeTools, nodeConverter)
	_ = graph.AddEdge(nodeConverter, END)

	r, err := graph.Compile(ctx)
	if err != nil {
		logs.Errorf("Compile failed, err=%v", err)
		// BUG FIX: must return here; falling through would invoke a nil runner.
		return
	}

	out, err := r.Invoke(ctx, map[string]any{"query": "我叫 zhangsan, 邮箱是 [email protected], 帮我推荐一处房产"})
	if err != nil {
		logs.Errorf("Invoke failed, err=%v", err)
		return
	}

	logs.Infof("result content: %v", out.Content)
}

// userInfoRequest is the argument payload the model supplies when it calls
// the "user_info" tool; field names mirror the tool's parameter schema.
type userInfoRequest struct {
Name string `json:"name"` // user's name
Email string `json:"email"` // user's email address (required by the system prompt)
}

// userInfoResponse is the result returned by the "user_info" tool; in this
// example the company/position/salary fields are hard-coded stub values.
type userInfoResponse struct {
Name string `json:"name"` // echoed back from the request
Email string `json:"email"` // echoed back from the request
Company string `json:"company"`
Position string `json:"position"`
Salary string `json:"salary"`
}
129 changes: 129 additions & 0 deletions compose/graph/two_model_chat/two_model_chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package main

import (
"context"
"errors"
"fmt"
"io"
"os"

"github.com/cloudwego/eino-ext/components/model/openai"
callbacks2 "github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino/utils/callbacks"

"github.com/cloudwego/eino-examples/internal/gptr"
)

// main runs a two-persona chat loop on a single model: a "writer" produces a
// joke and a "critic" gives feedback, cycling writer -> critic -> writer until
// the writer has run 3 rounds. Streaming output from both personas is printed
// via a callback handler feeding a printer goroutine.
func main() {
	openAIBaseURL := os.Getenv("OPENAI_BASE_URL")
	openAIAPIKey := os.Getenv("OPENAI_API_KEY")
	modelName := os.Getenv("MODEL_NAME")

	ctx := context.Background()

	modelConf := &openai.ChatModelConfig{
		BaseURL:     openAIBaseURL,
		APIKey:      openAIAPIKey,
		ByAzure:     true,
		Model:       modelName,
		Temperature: gptr.Of(float32(0.7)),
		APIVersion:  "2024-06-01",
	}

	// Shared graph state: the running conversation and the writer's round count.
	type state struct {
		currentRound int
		msgs         []*schema.Message
	}

	llm, err := openai.NewChatModel(ctx, modelConf)
	if err != nil {
		panic(err)
	}

	g := compose.NewGraph[[]*schema.Message, *schema.Message](compose.WithGenLocalState(func(ctx context.Context) *state { return &state{} }))
	// Pre-handlers accumulate history in state and prepend the persona prompt.
	_ = g.AddChatModelNode("writer", llm, compose.WithStatePreHandler[[]*schema.Message, *state](func(ctx context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) {
		state.currentRound++
		state.msgs = append(state.msgs, input...)
		input = append([]*schema.Message{schema.SystemMessage("you are a writer who writes jokes and revise it according to the critic's feedback. Prepend your joke with your name which is \"writer: \"")}, state.msgs...)
		return input, nil
	}), compose.WithNodeName("writer"))
	_ = g.AddChatModelNode("critic", llm, compose.WithStatePreHandler[[]*schema.Message, *state](func(ctx context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) {
		state.msgs = append(state.msgs, input...)
		input = append([]*schema.Message{schema.SystemMessage("you are a critic who ONLY gives feedback about jokes, emphasizing on funniness. Prepend your feedback with your name which is \"critic: \"")}, state.msgs...)
		return input, nil
	}), compose.WithNodeName("critic"))
	// ToList nodes adapt a single message back to the []*schema.Message the
	// chat model nodes expect.
	_ = g.AddLambdaNode("toList1", compose.ToList[*schema.Message]())
	_ = g.AddLambdaNode("toList2", compose.ToList[*schema.Message]())
	// Branch after the writer: stop after 3 rounds, otherwise loop to critic.
	_ = g.AddBranch("writer", compose.NewStreamGraphBranch(func(ctx context.Context, input *schema.StreamReader[*schema.Message]) (string, error) {
		input.Close()

		s, err := compose.GetState[*state](ctx)
		if err != nil {
			return "", err
		}

		if s.currentRound >= 3 {
			return compose.END, nil
		}

		return "toList1", nil
	}, map[string]bool{compose.END: true, "toList1": true}))
	_ = g.AddEdge(compose.START, "writer")
	_ = g.AddEdge("toList1", "critic")
	_ = g.AddEdge("critic", "toList2")
	_ = g.AddEdge("toList2", "writer")
	runner, err := g.Compile(ctx)
	if err != nil {
		panic(err)
	}

	sResponse := &streamResponse{
		ch: make(chan string),
	}
	// Printer goroutine: drains the channel until it is closed below.
	go func() {
		for m := range sResponse.ch {
			fmt.Print(m)
		}
	}()
	handler := callbacks.NewHandlerHelper().ChatModel(&callbacks.ModelCallbackHandler{
		OnEndWithStreamOutput: sResponse.OnStreamStart,
	}).Handler()

	outStream, err := runner.Stream(ctx, []*schema.Message{schema.UserMessage("write a funny line about robot, in 20 words.")},
		compose.WithCallbacks(handler))
	if err != nil {
		panic(err)
	}
	// Drain the final output stream; the visible text is printed by the
	// callback handler, so the values themselves are discarded.
	for {
		_, err := outStream.Recv()
		if err != nil {
			// BUG FIX: the original only broke on io.EOF, so any other
			// error spun forever and leaked the printer goroutine.
			close(sResponse.ch)
			if err != io.EOF {
				fmt.Printf("recv failed: %v\n", err)
			}
			break
		}
	}
	outStream.Close()
}

// streamResponse relays streamed model output chunks from the callback
// handler to a consumer goroutine through ch; the stream-draining side in
// main closes ch when the run ends.
type streamResponse struct {
ch chan string
}

// OnStreamStart is wired as the model's OnEndWithStreamOutput callback: it
// emits a separator line, then forwards each streamed chunk's message content
// to s.ch until the stream is exhausted. The input stream is always closed.
func (s *streamResponse) OnStreamStart(ctx context.Context, runInfo *callbacks2.RunInfo, input *schema.StreamReader[*model.CallbackOutput]) context.Context {
	defer input.Close()

	s.ch <- "\n=======\n"
	for {
		chunk, recvErr := input.Recv()
		switch {
		case errors.Is(recvErr, io.EOF):
			// Stream finished normally.
			return ctx
		case recvErr != nil:
			fmt.Printf("internal error: %s\n", recvErr)
			return ctx
		}

		s.ch <- chunk.Message.Content
	}
}

0 comments on commit f3d1920

Please sign in to comment.