Skip to content

Commit

Permalink
fix the chat template in llama.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
tybalex committed Apr 18, 2024
1 parent 71c3d33 commit 4c238e6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 42 deletions.
10 changes: 5 additions & 5 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16870,7 +16870,7 @@ static int32_t llama_chat_apply_template_internal(
bool strip_message = tmpl.find("content.strip()") != std::string::npos;
// construct the prompt
bool is_inside_turn = true; // skip BOS at the beginning
ss << "[INST] ";
// ss << "[INST] ";
for (auto message : chat) {
std::string content = strip_message ? trim(message->content) : message->content;
std::string role(message->role);
Expand All @@ -16883,13 +16883,13 @@ static int32_t llama_chat_apply_template_internal(
ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
} else {
// if the model does not support system message, we still include it in the first message, but without <<SYS>>
ss << content << "\n";
ss << "<s>" << content << "\n";
}
} else if (role == "user") {
ss << content << " [/INST]";
} else if (role == "user" or role == "observation") {
ss << "[INST]" << content << " [/INST]";
} else {
ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
is_inside_turn = false;
// is_inside_turn = false;
}
}
// llama2 templates seem to not care about "add_generation_prompt"
Expand Down
80 changes: 43 additions & 37 deletions test_llamacpp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -324,7 +324,7 @@
},
{
"cell_type": "code",
"execution_count": 55,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -335,7 +335,7 @@
},
{
"cell_type": "code",
"execution_count": 56,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -346,26 +346,18 @@
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"origin\":\"San Francisco\",\"destination\":\"Cupertino\",\"mode\":\"driving\"}', name='calculate_distance')\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 0\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"origin\":\"Cupertino\",\"destination\":\"San Francisco\",\"mode\":\"driving\"}', name='calculate_distance')\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 1\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"origin\":\"San Francisco\",\"destination\":\"Cupertino\",\"mode\":\"air\"}', name='calculate_distance')\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 2\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"origin\":\"Cupertino\",\"destination\":\"San Francisco\",\"mode\":\"air\"}', name='calculate_distance')\n",
"Observation: Distance is 50 miles.\n",
"Observation: Distance is 100 miles.\n",
"Observation: Distance is 150 miles.\n",
"Observation: Distance is 200 miles.\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 3\n",
"Loop 0\n",
"\n",
"[AI response]:\n",
" The distance between San Francisco and Cupertino is 50 miles when driving and 50 miles when flying, regardless of the direction.\n"
" The distance between San Francisco and Cupertino is 50 miles by driving and 100 miles by air from San Francisco to Cupertino. From Cupertino to San Francisco, it is 100 miles by driving and 200 miles by air.\n"
]
}
],
Expand All @@ -381,7 +373,7 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -392,16 +384,18 @@
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"number_to_buy\":3}', name='orderUmbrella')\n",
"Observation: Order placed. the price is 10 dollars.\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 0\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"length\":8}', name='generate_password')\n",
"Observation: Password generated: 51c83034\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 1\n",
"\n",
"[AI response]:\n",
" Your order for 3 umbrellas has been placed successfully. The total cost is $10. Additionally, a password of length 8 has been generated for you: b735993b.\n"
" Your order for 3 umbrellas has been placed successfully. The total price is $10. Additionally, a password of length 8 has been generated for you: 51c83034.\n"
]
}
],
Expand All @@ -419,7 +413,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -429,27 +423,39 @@
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"directory\":\"documents\"}', name='list_files')\n",
"Tool Call: Function(arguments='{\"a\":\"4\",\"b\":\"6\"}', name='addition')\n",
"Observation: 10.0\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 0\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"filename\":\"report.docx\"}', name='get_file_size')\n",
"Tool Call: Function(arguments='{\"filename\":\"task.txt\"}', name='get_file_size')\n",
"Tool Call: Function(arguments='{\"filename\":\"notes.txt\"}', name='get_file_size')\n",
"Tool Call: Function(arguments='{\"a\":\"10\",\"b\":\"2\"}', name='addition')\n",
"Observation: 12.0\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 1\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"a\":\"12\",\"b\":\"5\"}', name='multiplication')\n",
"Observation: 60.0\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 2\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"a\":\"60\",\"b\":\"2\"}', name='division')\n",
"Observation: 30.0\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 3\n",
"\n",
"[AI response]:\n",
" The sizes of the files in the 'documents' directory are as follows: report.docx is 100 bytes, task.txt is 200 bytes, and notes.txt is 300 bytes.\n"
" The result of 4 plus 6 is 10. Adding 2 to that gives 12. Multiplying 12 by 5 gives 60, and dividing 60 by 2 gives 30.\n"
]
}
],
"source": [
"user_query3 = \"User tool to help me : What is four plus six? What is the result of that plus 2? Take the result and multiply by 5 and then divide by two\"\n",
"\n",
"\n",
"msgs = run_completion(get_mistral_rubra_response, user_query3, msgs)"
"msgs = run_completion(get_mistral_rubra_response, user_query3)"
]
},
{
Expand All @@ -461,7 +467,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -484,18 +490,18 @@
"Loop 1\n",
"\n",
"[AI response]:\n",
" The weather in Boston is 60 degrees Fahrenheit. The distance from Boston to New York City is 50 miles.\n"
" The temperature in Boston is 60 degrees Fahrenheit. The distance from Boston to New York City is 50 miles.\n"
]
}
],
"source": [
"user_query4 = \"check the weather in boston, calculate the distance from boston to NYC for me only if it's less than 100 degrees Fahrenheit\"\n",
"msgs = run_completion(get_mistral_rubra_response, user_query4, msgs)"
"msgs = run_completion(get_mistral_rubra_response, user_query4)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand All @@ -505,26 +511,26 @@
"Pointing to URL: http://localhost:8019/v1/\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"location\":\"Boston\",\"unit\":\"f\"}', name='getCurrentWeather')\n",
"Tool Call: Function(arguments='{\"location\":\"Boston, MA\",\"unit\":\"f\"}', name='getCurrentWeather')\n",
"\n",
"Observation: temprature is 56 degree\n",
"Observation: temprature is 60 degree\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 0\n",
"\n",
"[AI calling functions]:\n",
"Tool Call: Function(arguments='{\"origin\":\"Boston\",\"destination\":\"New York\",\"mode\":\"car\"}', name='calculate_distance')\n",
"Tool Call: Function(arguments='{\"origin\":\"Boston, MA\",\"destination\":\"New York City, NY\",\"mode\":\"car\"}', name='calculate_distance')\n",
"Observation: Distance is 50 miles.\n",
"Pointing to URL: http://localhost:8019/v1/\n",
"Loop 1\n",
"\n",
"[AI response]:\n",
" The temperature in Boston is 56 degrees Fahrenheit. The distance from Boston to New York is 50 miles.\n"
" The temperature in Boston is 60 degrees Fahrenheit. The distance from Boston to New York City is 50 miles.\n"
]
}
],
"source": [
"user_query5 = \"check the weather in boston, calculate the distance from boston to NYC for me only if it's greater than 100 degrees Fahrenheit\"\n",
"msgs = run_completion(get_mistral_rubra_response, user_query5, msgs)"
"msgs = run_completion(get_mistral_rubra_response, user_query5)"
]
},
{
Expand All @@ -536,7 +542,7 @@
},
{
"cell_type": "code",
"execution_count": 65,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -571,7 +577,7 @@
],
"source": [
"user_query6 = \"check the size of all files in the 'documents' directory.\"\n",
"msgs = run_completion(get_mistral_rubra_response, user_query6, msgs)"
"msgs = run_completion(get_mistral_rubra_response, user_query6)"
]
}
],
Expand Down

0 comments on commit 4c238e6

Please sign in to comment.