Skip to content

Commit

Permalink
command line arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
jcuberdruid committed Nov 5, 2023
1 parent 29b4de2 commit dae6bd9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
CXX = clang++
CXX = g++
CXXFLAGS = -Ofast -std=c++20
TARGET = runmain.out
SRC = main.cpp
Expand Down
50 changes: 40 additions & 10 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,51 @@

using namespace std;

int main()
{

#include <iostream>
#include <vector>
#include <string>
#include <cstdlib>
#include <cstring>

using namespace std;

// Assuming the Model, CallBackFunctionType, load_mnist_csv and one_hot_encode
// are defined somewhere in the program.

int main(int argc, char* argv[])
{
cout << "#################################################################" << endl;
cout << "Generating Model" << endl;
cout << "#################################################################" << endl;
// Model

// Default values
double learningRate = 0.00015666;
string optimizerType = "adam"; // supported: sgd, sgd_momentum, adagrad, rmsprop, adam
string runNote = "tanh_256_adam";
string optimizerType = "sgd"; // supported: sgd, sgd_momentum, adagrad, rmsprop, adam
string activationFunction = "sigmoid";
string runNote = "sigmoid_256_sgd";

// Parse command line arguments
for(int i = 1; i < argc; i++) {
if(strcmp(argv[i], "-lr") == 0 && i + 1 < argc) {
learningRate = atof(argv[++i]);
} else if(strcmp(argv[i], "-opt") == 0 && i + 1 < argc) {
optimizerType = argv[++i];
} else if(strcmp(argv[i], "-act") == 0 && i + 1 < argc) {
activationFunction = argv[++i];
} else if(strcmp(argv[i], "-note") == 0 && i + 1 < argc) {
runNote = argv[++i];
} else {
cerr << "Usage: " << argv[0] << " [-lr learning_rate] [-opt optimizer] [-act activation_function] [-note run_note]" << endl;
return 1;
}
}

Model test("cross_entropy", optimizerType, learningRate, runNote);

vector<CallBackFunctionType> callbacks;
test.addLayer("tanh", make_tuple(256, 784), callbacks);
test.addLayer("tanh", make_tuple(10, 256), callbacks);
test.addLayer(activationFunction, make_tuple(256, 784), callbacks);
test.addLayer(activationFunction, make_tuple(10, 256), callbacks);
test.infoLayers();

cout << "#################################################################" << endl;
Expand All @@ -40,10 +70,10 @@ int main()
vector<vector<int>> label_vec = one_hot_encode(labelNums, 10);

// subset data for faster testing:
label_vec.resize(5000);
images.resize(5000);
//label_vec.resize(5000);
//images.resize(5000);

test.teach(label_vec, images, 1);
test.teach(label_vec, images, 100);

cout << "train accuracy " << test.getLastAccuracy() << endl;

Expand Down

0 comments on commit dae6bd9

Please sign in to comment.