Skip to content

Commit

Permalink
update function.cpp to support typescript functiin description + json…
Browse files Browse the repository at this point in the history
… function in/out
  • Loading branch information
tybalex committed May 10, 2024
1 parent 6d88574 commit 1c2ab9d
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 85 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/yaml-parser.hpp examples/server/function-call.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/yaml-cpp/libyaml-cpp.a examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/server/json-schema-to-grammar.mjs.hpp common/stb_image.h ggml.o llama.o func_scanner.o func_parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
server: examples/server/server.cpp examples/server/utils.hpp examples/server/python-parser.hpp examples/server/function-call-parser.hpp examples/server/function-call.hpp examples/server/tree_sitter/libtree-sitter.a examples/server/yaml-cpp/libyaml-cpp.a examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/server/json-schema-to-grammar.mjs.hpp common/stb_image.h ggml.o llama.o func_scanner.o func_parser.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -I examples/server/tree_sitter -I examples/server/yaml-cpp -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) $(call GET_OBJ_FILE, $<) -Iexamples/server -o $@ $(LDFLAGS) $(LWINSOCK2)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <yaml-cpp/yaml.h>

using json = nlohmann::ordered_json;

using namespace std;

static json get_value(string target_arg_type, const YAML::Node & node) {
json arg_json;
Expand Down Expand Up @@ -94,3 +94,52 @@ std::vector<json> rubra_fc_yaml_tool_extractor(const std::string& source_string,
return calls;
}


std::string generate_uuid() {
static std::random_device rd;
static std::mt19937 generator(rd());
static std::uniform_int_distribution<int> distribution(0, 15);

const char *v = "0123456789abcdef";
std::stringstream uuid;

for (int i = 0; i < 8; ++i) {
uuid << v[distribution(generator)];
}
return uuid.str();
}

std::vector<json> rubra_fc_json_tool_extractor(const std::string& output_str) {
std::vector<json> result;
if (output_str.find("<functions>") == std::string::npos) {
return result;
}

std::string str_to_parse = output_str.substr(output_str.find("<functions>") + std::string("<functions>").length());
std::istringstream stream(str_to_parse);
std::string line;
std::vector<json> function_call_json;

try {
while (std::getline(stream, line)) {
json fc = json::parse(line);
if (fc["arguments"].is_string()) {
fc["arguments"] = json::parse(fc["arguments"].get<std::string>());
}
function_call_json.push_back(fc);
}
} catch (const std::exception& e) {
std::cerr << "Error: " << e.what() << std::endl;
}

for (const auto& fc : function_call_json) {
json func_call;
func_call["id"] = generate_uuid();
func_call["name"] = fc["name"];
func_call["kwargs"] = fc["arguments"];
func_call["type"] = "function";
result.push_back(func_call);
}

return result;
}
33 changes: 27 additions & 6 deletions examples/server/function-call.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ string rubra_format_typescript_function_call_str(const std::vector<json> &functi
for (const auto& def : function_definitions) {
final_str += def + "\n\n";
}
final_str += "Use the following format if using a tool:\n<<functions>>[toolname1(arg1=value1, arg2=value2, ...), toolname2(arg1=value1, arg2=value2, ...)]\nYou can choose to respond with 1 or more tool calls at once, or with a chat message back to the user. Only make tool calls once you have all the details to fill in the required params. Feel free to ask the user for more info when appropriate. Any tool call you make must match the name of a function(s) provided above.";
final_str += "You can choose to respond with one or more tool calls at once, or with a chat message back to the user. Ensure you have all necessary details before making tool calls. If additional information is needed, ask the user appropriately. Any tool call you make must correspond to the functions listed above.\nIf you decide to call tools, format your response in JSONL. Start with the keyword `<functions>` followed by the JSON object:\n`<functions>{\"name\": \"<function_name>\", \"arguments\": {\"<arg1_name>\": \"<arg1_value>\", \"<arg2_name>\": \"<arg2_value>\", ...}}`";
return final_str;

}
Expand Down Expand Up @@ -546,17 +546,38 @@ static string construct_yaml_tool_call_str(const json & tool_calls, nlohmann::or
}


std::string construct_json_tool_call_str(const json& tool_calls, nlohmann::ordered_map<std::string, std::string> & func_observation_map) {
std::string tool_call_str;
bool first = true;
for (const auto& tool_call : tool_calls) {
std::string tool_call_id = tool_call["id"];
func_observation_map[tool_call_id] = ""; // Initialize with empty value, updated later from the message with tool role

if (!first) {
tool_call_str += "\n";
}
json tc = tool_call["function"];
if (tc["arguments"].is_string()) {
tc["arguments"] = json::parse(tc["arguments"].get<std::string>());
}
tool_call_str += tc.dump();
first = false;
}

return std::string("<functions>") + tool_call_str;
}


const vector<json> expand_messages(const json & body, json &tool_name_map, json& func_arg_type) {
string function_str = "";
if (body.contains("tools") && !body["tools"].empty()) {
// function_str = rubra_format_typescript_function_call_str(body["tools"], tool_name_map);
function_str = rubra_format_yaml_function_call_str(body["tools"], func_arg_type);
function_str = rubra_format_typescript_function_call_str(body["tools"], tool_name_map);
// function_str = rubra_format_yaml_function_call_str(body["tools"], func_arg_type);
}
// If 'tool' is not set or empty, check 'functions'
else if (body.contains("functions") && !body["functions"].empty()) {
// function_str = rubra_format_typescript_function_call_str(body["functions"], tool_name_map);
function_str = rubra_format_yaml_function_call_str(body["functions"], func_arg_type);
function_str = rubra_format_typescript_function_call_str(body["functions"], tool_name_map);
// function_str = rubra_format_yaml_function_call_str(body["functions"], func_arg_type);
}

if (function_str != "") {
Expand Down Expand Up @@ -601,7 +622,7 @@ const vector<json> expand_messages(const json & body, json &tool_name_map, json&
else if (body["messages"][i]["role"] == "assistant" and body["messages"][i].contains("tool_calls")){
// convert OpenAI function call format to Rubra format
// string tool_call_str = construct_python_tool_call_str(body["messages"][i]["tool_calls"], func_observation_map);
string tool_call_str = construct_yaml_tool_call_str(body["messages"][i]["tool_calls"], func_observation_map);
string tool_call_str = construct_json_tool_call_str(body["messages"][i]["tool_calls"], func_observation_map);
json function_call;
function_call["role"] = "function";
function_call["content"] = tool_call_str;
Expand Down
10 changes: 7 additions & 3 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "json.hpp"
#include "python-parser.hpp"
#include "function-call.hpp"
#include "yaml-parser.hpp"
#include "function-call-parser.hpp"

#include <string>
#include <vector>
Expand Down Expand Up @@ -448,7 +448,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
std::string content = json_value(result, "content", std::string(""));

// std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
std::vector<json> parsed_content = rubra_fc_yaml_tool_extractor(content, request["func_arg_type"]);
// std::vector<json> parsed_content = rubra_fc_yaml_tool_extractor(content, request["func_arg_type"]);
std::vector<json> parsed_content = rubra_fc_json_tool_extractor(content);

std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
Expand All @@ -475,6 +476,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
json tool_call;
tool_call["id"] = pc["id"];
tool_call["type"] = "function";

tool_call["function"] = json{
{"name" , pc["name"]},
{"arguments" , pc["kwargs"].dump()},
Expand Down Expand Up @@ -531,7 +533,9 @@ static std::vector<json> format_partial_response_oaicompat(json request ,json re
std::string content = json_value(result, "content", std::string(""));

// std::vector<json> parsed_content = parsePythonFunctionCalls(content, request["tool_name_map"]);
std::vector<json> parsed_content = rubra_fc_yaml_tool_extractor(content, request["func_arg_type"]);
// std::vector<json> parsed_content = rubra_fc_yaml_tool_extractor(content, request["func_arg_type"]);
std::vector<json> parsed_content = rubra_fc_json_tool_extractor(content);

std::time_t t = std::time(0);
if (!parsed_content.empty()) {
std::vector<json> res;
Expand Down
Loading

0 comments on commit 1c2ab9d

Please sign in to comment.