actually take advantage of using namespace std;
arrufat committed Jan 23, 2025
1 parent 4d116ab commit 1e09fc1
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions examples/slm_basic_train_ex.cpp
@@ -154,7 +154,7 @@ int main(int argc, char** argv)
// ----------------------------------------------------------------------------------------
if (parser.option("train"))
{
std::cout << "=== TRAIN MODE ===\n";
cout << "=== TRAIN MODE ===\n";

// 1) Prepare training data (simple approach)
// We will store characters from shakespeare_text into a vector
@@ -163,7 +163,7 @@ int main(int argc, char** argv)
auto full_tokens = char_based_tokenize(shakespeare_text);
if (full_tokens.empty())
{
std::cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n";
cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n";
return 0;
}

@@ -173,13 +173,13 @@ int main(int argc, char** argv)
: 0;

// Display the size of the training text and the number of sequences
std::cout << "Training text size: " << full_tokens.size() << " characters\n";
std::cout << "Maximum number of sequences: " << max_sequences << "\n";
cout << "Training text size: " << full_tokens.size() << " characters\n";
cout << "Maximum number of sequences: " << max_sequences << "\n";

// Check if the text is too short
if (max_sequences == 0)
{
std::cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least "
cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least "
<< (max_seq_len + 1) << " characters.\n";
return 0;
}
@@ -203,7 +203,7 @@ int main(int argc, char** argv)
// Shuffle samples and labels if the --shuffle option is enabled
if (parser.option("shuffle"))
{
std::cout << "Shuffling training sequences and labels...\n";
cout << "Shuffling training sequences and labels...\n";
shuffle_samples_and_labels(samples, labels);
}

@@ -232,41 +232,41 @@ int main(int argc, char** argv)
if (predicted[i] == labels[i])
correct++;
double accuracy = (double)correct / labels.size();
std::cout << "Training accuracy (on this sample set): " << accuracy << "\n";
cout << "Training accuracy (on this sample set): " << accuracy << "\n";

// 7) Save the model
net.clean();
serialize(model_file) << net;
std::cout << "Model saved to " << model_file << "\n";
cout << "Model saved to " << model_file << "\n";
}

// ----------------------------------------------------------------------------------------
// Generate mode
// ----------------------------------------------------------------------------------------
if (parser.option("generate"))
{
std::cout << "=== GENERATE MODE ===\n";
cout << "=== GENERATE MODE ===\n";
// 1) Load the trained model
using net_infer = my_transformer_cfg::network_type<false>;
net_infer net;
if (file_exists(model_file))
{
deserialize(model_file) >> net;
std::cout << "Loaded model from " << model_file << "\n";
cout << "Loaded model from " << model_file << "\n";
}
else
{
std::cerr << "Error: model file not found. Please run --train first.\n";
cerr << "Error: model file not found. Please run --train first.\n";
return 0;
}
std::cout << my_transformer_cfg::model_info::describe() << std::endl;
std::cout << "Model parameters: " << count_parameters(net) << std::endl << std::endl;
cout << my_transformer_cfg::model_info::describe() << endl;
cout << "Model parameters: " << count_parameters(net) << endl << endl;

// 2) Get the prompt from the included slm_data.h
std::string prompt_text = shakespeare_prompt;
if (prompt_text.empty())
{
std::cerr << "No prompt found in slm_data.h.\n";
cerr << "No prompt found in slm_data.h.\n";
return 0;
}
// If prompt is longer than max_seq_len, we keep only the first window
@@ -287,7 +287,7 @@ int main(int argc, char** argv)
input_seq(i, 0) = PAD_TOKEN;
}

std::cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;
cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;

// 3) Generate new text
// We'll predict one character at a time, then shift the window
@@ -296,22 +296,22 @@ int main(int argc, char** argv)
const int next_char = net(input_seq); // single inference

// Print the generated character
std::cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << std::flush;
cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << flush;

// Shift left by 1
for (long i = 0; i < max_seq_len - 1; ++i)
input_seq(i, 0) = input_seq(i + 1, 0);
input_seq(max_seq_len - 1, 0) = std::min(next_char, MAX_TOKEN_ID);
}

std::cout << "\n\n(end of generation)\n";
cout << "\n\n(end of generation)\n";
}

return 0;
}
catch (std::exception& e)
catch (exception& e)
{
std::cerr << "Exception thrown: " << e.what() << std::endl;
cerr << "Exception thrown: " << e.what() << endl;
return 1;
}
}
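Note on the change pattern: the example already has a file-scope using namespace std; directive (as the commit title implies), so the std:: qualifiers on cout, cerr, endl, flush, and exception were redundant and are dropped in this diff. A minimal, self-contained sketch of the same pattern; the file below is illustrative only and is not code taken from examples/slm_basic_train_ex.cpp:

    #include <exception>
    #include <iostream>

    // A file-scope using-directive pulls the std names into scope,
    // so the stream objects and manipulators no longer need std::.
    using namespace std;

    int main()
    {
        try
        {
            cout << "=== TRAIN MODE ===\n";   // was std::cout
            cerr << "example warning\n";      // was std::cerr
            cout << "done" << endl;           // was std::endl
            return 0;
        }
        catch (exception& e)                  // was std::exception
        {
            cerr << "Exception thrown: " << e.what() << endl;
            return 1;
        }
    }

The using-directive makes the std:: qualifier optional rather than invalid, which is why calls such as std::min can stay qualified in the diff and still compile alongside the unqualified stream calls.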
