Skip to content

Commit

Permalink
simplify code a little
Browse files Browse the repository at this point in the history
  • Loading branch information
davisking committed Jan 19, 2025
1 parent 3ab9801 commit 19121aa
Showing 1 changed file with 20 additions and 62 deletions.
82 changes: 20 additions & 62 deletions examples/slm_basic_train_ex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@
// ----------------------------------------------------------------------------------------

// We treat each character as a token ID in [0..255].
static const int MAX_TOKEN_ID = 255;
static const int PAD_TOKEN = 256; // an extra "pad" token if needed
const int MAX_TOKEN_ID = 255;
const int PAD_TOKEN = 256; // an extra "pad" token if needed

// For simplicity, we assume each line from shakespeare_text is appended, ignoring them.
static std::vector<int> char_based_tokenize(const std::string& text)
std::vector<int> char_based_tokenize(const std::string& text)
{
std::vector<int> tokens;
tokens.reserve(text.size());
for (unsigned char c : text)
for (const int c : text)
{
tokens.push_back(std::min<int>(c, MAX_TOKEN_ID));
tokens.push_back(std::min(c, MAX_TOKEN_ID));
}
return tokens;
}
Expand Down Expand Up @@ -108,44 +108,18 @@ int main(int argc, char** argv)

if (parser.number_of_arguments() == 0 && !parser.option("train") && !parser.option("generate"))
{
std::cout << "Usage:\n"
<< " --train : Train a small transformer model on the Shakespeare text\n"
<< " --generate : Generate text from a trained model using a prompt\n"
<< " --learning-rate <value> : Set the learning rate for training (default: 1e-4)\n"
<< " --batch-size <value> : Set the mini-batch size for training (default: 64)\n"
<< " --generation-length <value> : Set the length of generated text (default: 400)\n"
<< " --alpha <value> : Set the initial learning rate for Adam optimizer (default: 0.004)\n"
<< " --beta1 <value> : Set the decay rate for the first moment estimate (default: 0.9)\n"
<< " --beta2 <value> : Set the decay rate for the second moment estimate (default: 0.999)\n"
<< " --max-samples <value> : Set the maximum number of training samples (default: 50000)\n"
<< " --shuffle : Shuffle training sequences and labels before training (default: false)\n";
parser.print_options();
return 0;
}

// Default values
double learning_rate = 1e-4;
long batch_size = 64;
int generation_length = 400;
double alpha = 0.004; // Initial learning rate for Adam
double beta1 = 0.9; // Decay rate for the first moment estimate
double beta2 = 0.999; // Decay rate for the second moment estimate
size_t max_samples = 50000; // Default maximum number of training samples

// Override defaults if options are provided
if (parser.option("learning-rate"))
learning_rate = std::stod(parser.option("learning-rate").argument());
if (parser.option("batch-size"))
batch_size = std::stol(parser.option("batch-size").argument());
if (parser.option("generation-length"))
generation_length = std::stoi(parser.option("generation-length").argument());
if (parser.option("alpha"))
alpha = std::stod(parser.option("alpha").argument());
if (parser.option("beta1"))
beta1 = std::stod(parser.option("beta1").argument());
if (parser.option("beta2"))
beta2 = std::stod(parser.option("beta2").argument());
if (parser.option("max-samples"))
max_samples = std::stoul(parser.option("max-samples").argument());
const double learning_rate = get_option(parser, "learning-rate", 1e-4);
const long batch_size = get_option(parser, "batch-size", 64);
const int generation_length = get_option(parser, "generation-length", 400);
const double alpha = get_option(parser, "alpha", 0.004); // Initial learning rate for Adam
const double beta1 = get_option(parser, "beta1", 0.9); // Decay rate for the first moment estimate
const double beta2 = get_option(parser, "beta2", 0.999); // Decay rate for the second moment estimate
const size_t max_samples = get_option(parser, "max-samples",50000); // Default maximum number of training samples

// We define a minimal config for demonstration
const long vocab_size = 257; // 0..255 for chars + 1 pad token
Expand Down Expand Up @@ -297,7 +271,7 @@ int main(int argc, char** argv)
prompt_text.erase(prompt_text.begin() + max_seq_len, prompt_text.end());

// Convert prompt to a token sequence
auto prompt_tokens = char_based_tokenize(prompt_text);
const auto prompt_tokens = char_based_tokenize(prompt_text);

// Put into a dlib matrix
dlib::matrix<int, 0, 1> input_seq(max_seq_len, 1);
Expand All @@ -310,37 +284,21 @@ int main(int argc, char** argv)
input_seq(i, 0) = PAD_TOKEN;
}

std::cout << "Initial prompt:\n" << prompt_text << " (...)\n\nGenerated text:\n" << prompt_text;
std::cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;

// 3) Generate new text
// We'll predict one character at a time, then shift the window
// until the total length is at least generation_length and we encounter two newlines.
std::string generated_text = prompt_text;
bool stop_generation = false;

while (generated_text.size() < (size_t)generation_length || !stop_generation)
for (int i = 0; i < generation_length; ++i)
{
unsigned long next_char = net(input_seq); // single inference

// Append the generated character to the text
generated_text += (char)(std::min<unsigned long>(next_char, MAX_TOKEN_ID));
const int next_char = net(input_seq); // single inference

// Print the generated character
std::cout << (char)(std::min<unsigned long>(next_char, MAX_TOKEN_ID));
std::cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << std::flush;

// Shift left by 1
for (long i = 0; i < max_seq_len - 1; ++i)
input_seq(i, 0) = input_seq(i + 1, 0);
input_seq(max_seq_len - 1, 0) = (int)std::min<unsigned long>(next_char, MAX_TOKEN_ID);

// Check if the last two characters are newlines
if (generated_text.size() >= 2 &&
generated_text[generated_text.size() - 1] == '\n' &&
generated_text[generated_text.size() - 2] == '\n')
{
// Stop generation if the minimum length is reached
if (generated_text.size() >= (size_t)generation_length) stop_generation = true;
}
input_seq(max_seq_len - 1, 0) = std::min(next_char, MAX_TOKEN_ID);
}

std::cout << "\n\n(end of generation)\n";
Expand Down Expand Up @@ -391,4 +349,4 @@ int main(int argc, char** argv)
* > QUEEN ELIZABETH:
* > I go. Write to me very shortly.
* > And you shall understand from me her mind.
*/
*/

0 comments on commit 19121aa

Please sign in to comment.