Skip to content

Commit

Permalink
AppbuilderClient.Run support ToolChoice && add end_user_id (baidubce#494
Browse files Browse the repository at this point in the history
)

* AppbuilderClient support ToolChoice

* add go&&java

* update

* AppbuilderClient.run support end_user_id

* fix bug

* update

* update

* update

* update
  • Loading branch information
userpj authored Sep 11, 2024
1 parent e7fa14f commit 2aa73f0
Showing 4 changed files with 114 additions and 16 deletions.
32 changes: 19 additions & 13 deletions appbuilder/core/console/appbuilder_client/appbuilder_client.py
Original file line number Diff line number Diff line change
@@ -169,7 +169,7 @@ def upload_local_file(self, conversation_id, local_file_path: str) -> str:

if len(conversation_id) == 0:
raise ValueError("conversation_id is empty, you can run self.create_conversation to get a conversation_id")

filepath = os.path.abspath(local_file_path)
if not os.path.exists(filepath):
raise FileNotFoundError(f"{filepath} does not exist")
@@ -196,24 +196,28 @@ def run(self, conversation_id: str,
stream: bool = False,
tools: list[data_class.Tool] = None,
tool_outputs: list[data_class.ToolOutput] = None,
tool_choice: data_class.ToolChoice = None,
end_user_id: str = None,
**kwargs
) -> Message:
r"""
参数:
query (str: 必须): query内容
conversation_id (str, 必须): 唯一会话ID,如需开始新的会话,请使用self.create_conversation创建新的会话
file_ids(list[str], 可选):
stream (bool, 可选): 为True时,流式返回,需要将message.content.answer拼接起来才是完整的回答;为False时,对应非流式返回
tools(list[data_class.Tools], 可选): 一个Tools组成的列表,其中每个Tools对应一个工具的配置, 默认为None
tool_outputs(list[data_class.ToolOutput], 可选): 工具输出列表,格式为list[ToolOutput], ToolOutputd内容为本地的工具执行结果,以自然语言/json dump str描述,默认为None
返回: message (obj: `Message`): 对话结果.
参数:
query (str: 必须): query内容
conversation_id (str, 必须): 唯一会话ID,如需开始新的会话,请使用self.create_conversation创建新的会话
file_ids(list[str], 可选):
stream (bool, 可选): 为True时,流式返回,需要将message.content.answer拼接起来才是完整的回答;为False时,对应非流式返回
tools(list[data_class.Tools], 可选): 一个Tools组成的列表,其中每个Tools对应一个工具的配置, 默认为None
tool_outputs(list[data_class.ToolOutput], 可选): 工具输出列表,格式为list[ToolOutput], ToolOutputd内容为本地的工具执行结果,以自然语言/json dump str描述,默认为None
tool_choice(data_class.ToolChoice, 可选): 控制大模型使用组件的方式,默认为None
end_user_id (str, 可选): 用户ID,用于区分不同用户
返回: message (obj: `Message`): 对话结果.
"""

if len(conversation_id) == 0:
raise ValueError(
"conversation_id is empty, you can run self.create_conversation to get a conversation_id"
)

if query == "" and (tool_outputs is None or len(tool_outputs) == 0):
raise ValueError("AppBuilderClient Run API: query and tool_outputs cannot both be empty")

@@ -224,7 +228,9 @@ def run(self, conversation_id: str,
stream=True if stream else False,
file_ids=file_ids,
tools=tools,
tool_outputs=tool_outputs
tool_outputs=tool_outputs,
tool_choice=tool_choice,
end_user_id=end_user_id,
)

headers = self.http_client.auth_header_v2()
@@ -244,7 +250,7 @@ def run(self, conversation_id: str,
out = data_class.AppBuilderClientAnswer()
_transform(resp, out)
return Message(content=out)

def run_with_handler(self,
conversation_id: str,
query: str = "",
@@ -263,7 +269,7 @@ def run_with_handler(self,
stream=stream,
**kwargs
)

return event_handler

@staticmethod
29 changes: 26 additions & 3 deletions appbuilder/core/console/appbuilder_client/data_class.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ class Function(BaseModel):
name: str = Field(..., description="工具名称")
description: str = Field(..., description="工具描述")
parameters: dict = Field(..., description="工具参数, json_schema格式")

class Tool(BaseModel):
type: str = "function"
function: Function = Field(..., description="工具信息")
@@ -40,6 +40,28 @@ class ToolCall(BaseModel):
type: str = Field("function", description="需要输出的工具调用的类型。就目前而言,这始终是function")
function: FunctionCallDetail = Field(..., description="函数定义")


class ToolChoiceFunction(BaseModel):
name: str = Field(
...,
description="组件的英文名称(唯一标识),用户通过工作流完成自定义组件后,可在个人空间-组件下查看组件英文名称",
)
input: dict = Field(
...,
description="当组件没有入参或者必填的入参只有一个时可省略,必填的入参只有一个且省略时,使用query字段的值作为入参",
)


class ToolChoice(BaseModel):
type: str = Field(
...,
description="auto/function,auto表示由LLM自动判断调什么组件;function表示由用户指定调用哪个组件",
)
function: Optional[ToolChoiceFunction] = Field(
..., description="当type为function时,需要指定调用哪个组件"
)


class AppBuilderClientRequest(BaseModel):
"""会话请求参数
属性:
@@ -56,6 +78,8 @@ class AppBuilderClientRequest(BaseModel):
app_id: str
tools: Optional[list[Tool]] = None
tool_outputs: Optional[list[ToolOutput]] = None
tool_choice: Optional[ToolChoice] = None
end_user_id: Optional[str] = None


class Usage(BaseModel):
@@ -107,7 +131,7 @@ class AppBuilderClientResponse(BaseModel):
message_id: str = ""
is_completion: Optional[bool] = False
content: list[OriginalEvent] = []


class TextDetail(BaseModel):
"""content_type=text,详情内容
@@ -286,4 +310,3 @@ class AppBuilderClientAppListResponse(BaseModel):
request_id: str = Field("", description="请求ID")
data: Optional[list[AppOverview]] = Field(
[], description="应用概览列表")

12 changes: 12 additions & 0 deletions go/appbuilder/app_builder_client_data.go
Original file line number Diff line number Diff line change
@@ -48,9 +48,11 @@ type AppBuilderClientRunRequest struct {
AppID string `json:"app_id"`
Query string `json:"query"`
Stream bool `json:"stream"`
EndUserID string `json:"end_user_id"`
ConversationID string `json:"conversation_id"`
Tools []Tool `json:"tools"`
ToolOutputs []ToolOutput `json:"tool_outputs"`
ToolChoice ToolChoice `json:"tool_choice"`
}

type Tool struct {
@@ -69,6 +71,16 @@ type ToolOutput struct {
Output string `json:"output" description:"工具输出"`
}

type ToolChoice struct {
Type string `json:"type"`
Function ToolChoiceFunction `json:"function"`
}

type ToolChoiceFunction struct {
Name string `json:"name"`
Input map[string]interface{} `json:"input"`
}

type AgentBuilderRawResponse struct {
RequestID string `json:"request_id"`
Date string `json:"date"`
Original file line number Diff line number Diff line change
@@ -10,9 +10,13 @@ public class AppBuilderClientRunRequest {
private boolean stream;
@SerializedName("conversation_id")
private String conversationID;
@SerializedName("end_user_id")
private String endUserId;
private Tool[] tools;
@SerializedName("tool_outputs")
private ToolOutput[] ToolOutputs;
@SerializedName("tool_choice")
private ToolChoice ToolChoice;

public String getAppId() {
return appId;
@@ -46,6 +50,14 @@ public void setConversationID(String conversationID) {
this.conversationID = conversationID;
}

public String getEndUserId() {
return endUserId;
}

public void setEndUserId(String endUserId) {
this.endUserId = endUserId;
}

public Tool[] getTools() {
return tools;
}
@@ -62,6 +74,14 @@ public void setToolOutputs(ToolOutput[] toolOutputs) {
this.ToolOutputs = toolOutputs;
}

public ToolChoice getToolChoice() {
return ToolChoice;
}

public void setToolChoice(ToolChoice toolChoice) {
this.ToolChoice = toolChoice;
}

public static class Tool {
private String type;
private Function function;
@@ -79,6 +99,7 @@ public Function getFunction() {
return function;
}


public static class Function {
private String name;
private String description;
@@ -122,4 +143,40 @@ public String getOutput() {
return output;
}
}

public static class ToolChoice {
private String type;
private Function function;

public ToolChoice(String type, Function function) {
this.type=type;
this.function=function;
}

public String getType() {
return type;
}

public Function getFunction() {
return function;
}

public static class Function {
private String name;
private Map<String, Object> input;

public Function(String name, Map<String, Object> input) {
this.name = name;
this.input = input;
}

public String getName() {
return name;
}

public Map<String, Object> getInput() {
return input;
}
}
}
}

0 comments on commit 2aa73f0

Please sign in to comment.