Skip to content

Commit

Permalink
add support to format input json as typescript function str
Browse files Browse the repository at this point in the history
  • Loading branch information
tybalex committed Apr 16, 2024
1 parent 2e33911 commit 49acdcb
Showing 1 changed file with 101 additions and 3 deletions.
104 changes: 101 additions & 3 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
//


static std::string rubra_format_function_call_str(const std::vector<json> & functions, json & tool_name_map) {
static std::string rubra_format_python_function_call_str(const std::vector<json> & functions, json & tool_name_map) {
std::string final_str = "You have access to the following tools:\n";
printf("rubra_format_function_call_str parsing...\n");
json type_mapping = {
Expand Down Expand Up @@ -432,6 +432,104 @@ static std::string rubra_format_function_call_str(const std::vector<json> & func
return final_str;
}


// Helper function to join strings with a delimiter
static std::string helper_join(const std::vector<std::string>& elements, const std::string& delimiter) {
std::string result;
for (auto it = elements.begin(); it != elements.end(); ++it) {
if (!result.empty()) {
result += delimiter;
}
result += *it;
}
return result;
}

static std::string rubra_format_typescript_function_call_str(const std::vector<json> &functions, json &tool_name_map) {
std::string final_str = "You have access to the following tools:\n";
json type_mapping = {
{"string", "string"},
{"integer", "number"},
{"number", "number"},
{"float", "number"},
{"object", "any"},
{"array", "any[]"},
{"boolean", "boolean"},
{"null", "null"}
};

std::vector<std::string> function_definitions;
for (const auto &function : functions) {
const auto &spec = function.contains("function") ? function["function"] : function;
std::string func_name = spec.value("name", "");
if (func_name.find('-') != std::string::npos) {
const std::string origin_func_name = func_name;
std::replace(func_name.begin(), func_name.end(), '-', '_'); // replace "-" with "_" because - is invalid in typescript func name
tool_name_map[func_name] = origin_func_name;
}

const std::string description = spec.contains("description") ? spec["description"].get<std::string>() : "";
const auto& parameters = spec.contains("parameters") ? spec["parameters"].value("properties", json({})) : json({});
const auto& required_params = spec.contains("parameters") ? spec["parameters"].value("required", std::vector<std::string>()) : std::vector<std::string>();

std::vector<std::string> func_args;
std::string docstring = "/**\n * " + description + "\n";

for (auto it = parameters.begin(); it != parameters.end(); ++it) {
const std::string param = it.key();
const json& details = it.value();
std::string json_type = details["type"].get<std::string>();
std::string ts_type = type_mapping.value(json_type, "any");
std::string param_description = "";
if (details.count("description") > 0) {
param_description = details["description"]; // Assuming the description is the first element
}
if (details.count("enum") > 0) {
std::string enum_values;
for (const std::string val : details["enum"]) {
if (!enum_values.empty()) {
enum_values += " or ";
}
enum_values = enum_values+ "\"" + val + "\"";
}
if (details["enum"].size() == 1) {
param_description += " Only Acceptable value is: " + enum_values;
} else {
param_description += " Only Acceptable values are: " + enum_values;
}
}
if (param_description.empty()) {
param_description = "No description provided.";
}
if (details.contains("enum")) {
ts_type = "string"; // Enum is treated as string in typescript
}
std::string arg_str = param + ": " + ts_type;
if (find(required_params.begin(), required_params.end(), param) == required_params.end()) {
arg_str = param + "?: " + ts_type;
docstring += " * @param " + param + " - " + param_description + "\n";
} else {
docstring += " * @param " + param + " - " + param_description + "\n";
}
func_args.push_back(arg_str);
}
docstring += " */\n";

std::string func_args_str = helper_join(func_args, ", ");
std::string function_definition = docstring + "function " + func_name + "(" + func_args_str + "): any {}";

function_definitions.push_back(function_definition);
}

for (const auto& def : function_definitions) {
final_str += def + "\n\n";
}
final_str += "Use the following format if using tools:\n<<functions>>[toolname1(arg1=value1, arg2=value2, ...), toolname2(arg1=value1, arg2=value2, ...)]";
return final_str;
}



static std::string default_tool_formatter(const std::vector<json>& tools) {
std::string toolText = "";
std::vector<std::string> toolNames;
Expand Down Expand Up @@ -493,12 +591,12 @@ static json oaicompat_completion_params_parse(

if (body.contains("tools") && !body["tools"].empty()) {
// function_str = default_tool_formatter(body["tool"]);
function_str = rubra_format_function_call_str(body["tools"], tool_name_map);
function_str = rubra_format_typescript_function_call_str(body["tools"], tool_name_map);
}
// If 'tool' is not set or empty, check 'functions'
else if (body.contains("functions") && !body["functions"].empty()) {
// function_str = default_tool_formatter(body["functions"]);
function_str = rubra_format_function_call_str(body["functions"], tool_name_map);
function_str = rubra_format_typescript_function_call_str(body["functions"], tool_name_map);
}
printf("\n=============Formatting Input from OPENAI format...============\n");
if (function_str != "") {
Expand Down

0 comments on commit 49acdcb

Please sign in to comment.