generated from cloudwego/.github
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add too_call_once and two_model_chat to graph examples
Change-Id: I4bc86918954cae88cc8f5b4fb4deea0d0895bfc7
- Loading branch information
1 parent
ff3a7a6
commit f3d1920
Showing
2 changed files
with
288 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"os" | ||
|
||
"github.com/cloudwego/eino-ext/components/model/openai" | ||
"github.com/cloudwego/eino/components/prompt" | ||
"github.com/cloudwego/eino/components/tool" | ||
"github.com/cloudwego/eino/components/tool/utils" | ||
. "github.com/cloudwego/eino/compose" | ||
"github.com/cloudwego/eino/schema" | ||
|
||
"github.com/cloudwego/eino-examples/internal/gptr" | ||
"github.com/cloudwego/eino-examples/internal/logs" | ||
) | ||
|
||
func main() { | ||
openAIBaseURL := os.Getenv("OPENAI_BASE_URL") | ||
openAIAPIKey := os.Getenv("OPENAI_API_KEY") | ||
modelName := os.Getenv("MODEL_NAME") | ||
|
||
ctx := context.Background() | ||
|
||
systemTpl := `你是一名房产经纪人,结合用户的薪酬和工作,使用 user_info API,为其提供相关的房产信息。邮箱是必须的` | ||
chatTpl := prompt.FromMessages(schema.FString, | ||
schema.SystemMessage(systemTpl), | ||
schema.MessagesPlaceholder("message_histories", true), | ||
schema.UserMessage("{query}"), | ||
) | ||
|
||
modelConf := &openai.ChatModelConfig{ | ||
BaseURL: openAIBaseURL, | ||
APIKey: openAIAPIKey, | ||
ByAzure: true, | ||
Model: modelName, | ||
Temperature: gptr.Of(float32(0.7)), | ||
APIVersion: "2024-06-01", | ||
} | ||
|
||
chatModel, err := openai.NewChatModel(ctx, modelConf) | ||
if err != nil { | ||
logs.Errorf("NewChatModel failed, err=%v", err) | ||
return | ||
} | ||
|
||
userInfoTool := utils.NewTool( | ||
&schema.ToolInfo{ | ||
Name: "user_info", | ||
Desc: "根据用户的姓名和邮箱,查询用户的公司、职位、薪酬信息", | ||
ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ | ||
"name": { | ||
Type: "string", | ||
Desc: "用户的姓名", | ||
}, | ||
"email": { | ||
Type: "string", | ||
Desc: "用户的邮箱", | ||
}, | ||
}), | ||
}, | ||
func(ctx context.Context, input *userInfoRequest) (output *userInfoResponse, err error) { | ||
return &userInfoResponse{ | ||
Name: input.Name, | ||
Email: input.Email, | ||
Company: "Bytedance", | ||
Position: "CEO", | ||
Salary: "9999", | ||
}, nil | ||
}) | ||
|
||
info, err := userInfoTool.Info(ctx) | ||
if err != nil { | ||
logs.Errorf("Get ToolInfo failed, err=%v", err) | ||
return | ||
} | ||
|
||
err = chatModel.BindTools([]*schema.ToolInfo{info}) | ||
if err != nil { | ||
logs.Errorf("BindTools failed, err=%v", err) | ||
return | ||
} | ||
|
||
toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{ | ||
Tools: []tool.BaseTool{userInfoTool}, | ||
}) | ||
if err != nil { | ||
logs.Errorf("NewToolNode failed, err=%v", err) | ||
return | ||
} | ||
|
||
takeOne := InvokableLambda(func(ctx context.Context, input []*schema.Message) (*schema.Message, error) { | ||
if len(input) == 0 { | ||
return nil, errors.New("input is empty") | ||
} | ||
return input[0], nil | ||
}) | ||
|
||
const ( | ||
nodeModel = "node_model" | ||
nodeTools = "node_tools" | ||
nodeTemplate = "node_template" | ||
nodeConverter = "node_converter" | ||
) | ||
|
||
branch := NewStreamGraphBranch(func(ctx context.Context, input *schema.StreamReader[*schema.Message]) (string, error) { | ||
defer input.Close() | ||
msg, err := input.Recv() | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
if len(msg.ToolCalls) > 0 { | ||
return "tools", nil | ||
} | ||
|
||
return END, nil | ||
}, map[string]bool{END: true, nodeTools: true}) | ||
|
||
graph := NewGraph[map[string]any, *schema.Message]() | ||
|
||
_ = graph.AddChatTemplateNode(nodeTemplate, chatTpl) | ||
_ = graph.AddChatModelNode(nodeModel, chatModel) | ||
_ = graph.AddToolsNode(nodeTools, toolsNode) | ||
_ = graph.AddLambdaNode(nodeConverter, takeOne) | ||
|
||
_ = graph.AddEdge(START, nodeTemplate) | ||
_ = graph.AddEdge(nodeTemplate, nodeModel) | ||
_ = graph.AddBranch(nodeModel, branch) | ||
_ = graph.AddEdge(nodeTools, nodeConverter) | ||
_ = graph.AddEdge(nodeConverter, END) | ||
|
||
r, err := graph.Compile(ctx) | ||
if err != nil { | ||
logs.Errorf("Compile failed, err=%v", err) | ||
} | ||
|
||
out, err := r.Invoke(ctx, map[string]any{"query": "我叫 zhangsan, 邮箱是 [email protected], 帮我推荐一处房产"}) | ||
if err != nil { | ||
logs.Errorf("Invoke failed, err=%v", err) | ||
return | ||
} | ||
|
||
logs.Infof("result content: %v", out.Content) | ||
} | ||
|
||
type userInfoRequest struct { | ||
Name string `json:"name"` | ||
Email string `json:"email"` | ||
} | ||
|
||
type userInfoResponse struct { | ||
Name string `json:"name"` | ||
Email string `json:"email"` | ||
Company string `json:"company"` | ||
Position string `json:"position"` | ||
Salary string `json:"salary"` | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"os" | ||
|
||
"github.com/cloudwego/eino-ext/components/model/openai" | ||
callbacks2 "github.com/cloudwego/eino/callbacks" | ||
"github.com/cloudwego/eino/components/model" | ||
"github.com/cloudwego/eino/compose" | ||
"github.com/cloudwego/eino/schema" | ||
"github.com/cloudwego/eino/utils/callbacks" | ||
|
||
"github.com/cloudwego/eino-examples/internal/gptr" | ||
) | ||
|
||
func main() { | ||
openAIBaseURL := os.Getenv("OPENAI_BASE_URL") | ||
openAIAPIKey := os.Getenv("OPENAI_API_KEY") | ||
modelName := os.Getenv("MODEL_NAME") | ||
|
||
ctx := context.Background() | ||
|
||
modelConf := &openai.ChatModelConfig{ | ||
BaseURL: openAIBaseURL, | ||
APIKey: openAIAPIKey, | ||
ByAzure: true, | ||
Model: modelName, | ||
Temperature: gptr.Of(float32(0.7)), | ||
APIVersion: "2024-06-01", | ||
} | ||
|
||
type state struct { | ||
currentRound int | ||
msgs []*schema.Message | ||
} | ||
|
||
llm, err := openai.NewChatModel(ctx, modelConf) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
g := compose.NewGraph[[]*schema.Message, *schema.Message](compose.WithGenLocalState(func(ctx context.Context) *state { return &state{} })) | ||
_ = g.AddChatModelNode("writer", llm, compose.WithStatePreHandler[[]*schema.Message, *state](func(ctx context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) { | ||
state.currentRound++ | ||
state.msgs = append(state.msgs, input...) | ||
input = append([]*schema.Message{schema.SystemMessage("you are a writer who writes jokes and revise it according to the critic's feedback. Prepend your joke with your name which is \"writer: \"")}, state.msgs...) | ||
return input, nil | ||
}), compose.WithNodeName("writer")) | ||
_ = g.AddChatModelNode("critic", llm, compose.WithStatePreHandler[[]*schema.Message, *state](func(ctx context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) { | ||
state.msgs = append(state.msgs, input...) | ||
input = append([]*schema.Message{schema.SystemMessage("you are a critic who ONLY gives feedback about jokes, emphasizing on funniness. Prepend your feedback with your name which is \"critic: \"")}, state.msgs...) | ||
return input, nil | ||
}), compose.WithNodeName("critic")) | ||
_ = g.AddLambdaNode("toList1", compose.ToList[*schema.Message]()) | ||
_ = g.AddLambdaNode("toList2", compose.ToList[*schema.Message]()) | ||
_ = g.AddBranch("writer", compose.NewStreamGraphBranch(func(ctx context.Context, input *schema.StreamReader[*schema.Message]) (string, error) { | ||
input.Close() | ||
|
||
s, err := compose.GetState[*state](ctx) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
if s.currentRound >= 3 { | ||
return compose.END, nil | ||
} | ||
|
||
return "toList1", nil | ||
}, map[string]bool{compose.END: true, "toList1": true})) | ||
_ = g.AddEdge(compose.START, "writer") | ||
_ = g.AddEdge("toList1", "critic") | ||
_ = g.AddEdge("critic", "toList2") | ||
_ = g.AddEdge("toList2", "writer") | ||
runner, err := g.Compile(ctx) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
sResponse := &streamResponse{ | ||
ch: make(chan string), | ||
} | ||
go func() { | ||
for m := range sResponse.ch { | ||
fmt.Print(m) | ||
} | ||
}() | ||
handler := callbacks.NewHandlerHelper().ChatModel(&callbacks.ModelCallbackHandler{ | ||
OnEndWithStreamOutput: sResponse.OnStreamStart, | ||
}).Handler() | ||
|
||
outStream, err := runner.Stream(ctx, []*schema.Message{schema.UserMessage("write a funny line about robot, in 20 words.")}, | ||
compose.WithCallbacks(handler)) | ||
if err != nil { | ||
panic(err) | ||
} | ||
for { | ||
_, err := outStream.Recv() | ||
if err == io.EOF { | ||
close(sResponse.ch) | ||
break | ||
} | ||
} | ||
} | ||
|
||
type streamResponse struct { | ||
ch chan string | ||
} | ||
|
||
func (s *streamResponse) OnStreamStart(ctx context.Context, runInfo *callbacks2.RunInfo, input *schema.StreamReader[*model.CallbackOutput]) context.Context { | ||
defer input.Close() | ||
s.ch <- "\n=======\n" | ||
for { | ||
frame, err := input.Recv() | ||
if errors.Is(err, io.EOF) { | ||
break | ||
} | ||
if err != nil { | ||
fmt.Printf("internal error: %s\n", err) | ||
return ctx | ||
} | ||
|
||
s.ch <- frame.Message.Content | ||
} | ||
return ctx | ||
} |