// AI.cpp — executable file, 142 lines (124 loc), 4.33 KB.
// (GitHub page chrome — notification banner and line-number gutter — removed from this capture.)
#include "AI.h"
#include "Utility.h"
#include <chrono>
#include <fstream>
#include <iostream>
#include <memory>
#include <caffe/caffe.hpp>
#include <cuda_runtime.h>
using namespace Utility;
using namespace isis;
using namespace caffe;
/// Builds the AI worker number `index`, which consumes frames from `Q`.
/// Configures Caffe (GPU mode on the constructing thread, quieter glog output),
/// resets the per-phase sample counters, makes sure the prediction image
/// directory exists and registers the signals this component can raise.
AI::AI(unsigned int index, SyncQueue<cv::Mat>& Q)
: Application ("AI_"+std::to_string(index))
, index_ (index)
, imgQueue (Q)
, isLearning_ (false)
{
    Caffe::set_mode(Caffe::GPU);
    // glog/gflags "minloglevel" 2 == ERROR: hide Caffe's INFO and WARNING chatter.
    google::SetCommandLineOption("minloglevel", "2");
    // pull_img_queue() reads and increments these counters; start them at a
    // known value instead of relying on the header to initialize them.
    predictDataSetSize_ = 0;
    currentDataSetSize_ = 0;
    create_directory(IMG_DIR + "/" + PREDICT_DIR);
    set_events();
}
/// Releases every network cached by load_network().
/// load_network() allocates each cached Net with `new` and stores the raw
/// pointer in networks_, so this object owns them and must delete them.
/// NOTE(review): this assumes AI is never copied (a copy would share these
/// pointers) — confirm the class/base is non-copyable.
AI::~AI() {
    for (auto& entry : networks_) {
        delete entry.second;
    }
    networks_.clear();
}
void AI::train(const std::string& proto_filename){
isLearning_ = true;
//start our training
SolverParameter solver_param;
ReadProtoFromTextFile(proto_filename, &solver_param);
SGDSolver<float> solver(solver_param);
solver.Solve();
//cleanup
log()->info("Optimization complete...");
raiseEvent("StackedDAETrainingComplete", true);
isLearning_ = false;
}
/// Returns line number `line_number` (zero-based) of the text file `file_list`.
/// Returns an empty string if the file cannot be opened or has fewer lines.
/// Previously, an out-of-range index returned whatever getline left behind.
const std::string AI::indexToStr(const std::string& file_list, uint64_t line_number){
    std::ifstream infile(file_list);
    std::string line;
    for (uint64_t cur_line = 0; std::getline(infile, line); ++cur_line) {
        if (cur_line == line_number) {
            return line;
        }
    }
    return std::string();
}
void AI::predict(const std::string& network_proto){
auto net = load_network(network_proto
, MODEL_DIR + "/" + IMAGENET_MODEL); //XX move loading to imgNet ptr
float loss; // Run Forward pass
auto& result = net->ForwardPrefilled(&loss);
// Now result will contain the argmax.
const float* argmaxs = result[1]->cpu_data();
for (int i = 0; i < result[1]->num(); ++i) {
log()->info("[%s] predicts image %d as [%s]"
, net->name()
, i
, indexToStr(MODEL_DIR + "/" + SYNSET_FILE
, static_cast<uint64_t>(argmaxs[i])));
}
//clean up all the prediction files from this trial
removeAll(IMG_DIR + "/" + PREDICT_DIR);
}
/// Returns the network for `proto_model`, creating and caching it on first use.
/// The weights are loaded from `pretrained_model`.
/// The cache key is `proto_model` only: a later call with a different
/// `pretrained_model` gets the already-cached net. The returned pointer is owned
/// by this AI (networks_). A net is cached only after its weights loaded
/// successfully, so a failed load never leaves an untrained net behind.
Net<float>* AI::load_network(const std::string& proto_model
                           , const std::string& pretrained_model)
{
    auto cached = networks_.find(proto_model);
    if (cached != networks_.end()) {
        return cached->second;
    }
    log()->debug("Created new network from [%s | %s]", proto_model, pretrained_model);
    // Hold the net in a unique_ptr until it is fully initialized and stored,
    // so an exception from either step frees it.
    std::unique_ptr<Net<float>> net(new Net<float>(proto_model));
    net->CopyTrainedLayersFrom(pretrained_model); //get trained net
    networks_[proto_model] = net.get();
    return net.release();
}
/// Saves `img` as `directory`/`imagename`.JPEG.
/// Then appends a "<image path> <classif>" line to the list file
/// `directory`/`filename`. The list entry is written only if the image was
/// saved, so the list never references a missing file. Failures are logged,
/// not thrown, although cv::imwrite itself may throw for invalid images.
void AI::write_to_file(const cv::Mat &img
                     , const std::string& imagename
                     , const std::string& directory
                     , const std::string& filename
                     , uint64_t classif){
    const std::string latestfile = directory + "/" + imagename + ".JPEG";
    if (!cv::imwrite(latestfile, img)) {
        log()->error("Failed to write image [%s]", latestfile);
        return;
    }
    const std::string listFile = directory + "/" + filename;
    std::ofstream of(listFile, std::ios_base::app); // app implies out; every write goes to the end
    if (!of) {
        log()->error("Failed to open list file [%s]", listFile);
        return;
    }
    of << latestfile << " " << classif << '\n'; // flushed when `of` is destroyed
}
/// Centers `img` in place by subtracting its per-channel mean.
/// NOTE(review): cv::subtract saturates, so for unsigned 8-bit images any
/// pixel below the mean is clamped to 0 rather than going negative.
void AI::normalize(cv::Mat& img){
    const cv::Scalar channelMeans = cv::mean(img);
    cv::subtract(img, channelMeans, img);
    // TODO: Perhaps std dev normalize?
}
/// Processes at most one frame from imgQueue.
/// Does nothing if no frame was available; otherwise the frame is mean-normalized.
/// - Predict mode (!isLearning_): frames are buffered under IMG_DIR/PREDICT_DIR.
///   The first frame after BUFFER_SIZE buffered ones triggers predict() instead
///   of being stored.
/// - Learning mode: frames are written to IMG_DIR with a timestamp name. Samples
///   with index above 80% of DATA_SET_SIZE go to the testing list, the rest to
///   the training list. The frame after DATA_SET_SIZE samples triggers train().
/// Counters are reset before predict()/train() run, so an exception from
/// either one does not make every later frame re-trigger it.
void AI::pull_img_queue() {
    // Drop the previous frame first. If the queue is empty, latestImg_ stays
    // empty and the stale frame is not re-normalized and rewritten every loop.
    latestImg_.release();
    imgQueue.TryDequeue(latestImg_);
    if (latestImg_.empty()) {
        return;
    }
    normalize(latestImg_);

    if (!isLearning_) {
        if (predictDataSetSize_++ < BUFFER_SIZE) {
            // predict() removes this directory, so recreate it on demand.
            create_directory(IMG_DIR + "/" + PREDICT_DIR);
            write_to_file(latestImg_
                          , std::to_string(predictDataSetSize_)
                          , IMG_DIR + "/" + PREDICT_DIR
                          , PREDICT_LIST_FILE
                          , 0);
        } else {
            predictDataSetSize_ = 0; //reset
            predict(MODEL_DIR + "/" + PROTO_FILE);
        }
        return;
    }

    if (currentDataSetSize_++ < DATA_SET_SIZE) {
        const bool isTestSample = currentDataSetSize_ > 0.8f * DATA_SET_SIZE;
        const std::string listFile = isTestSample ? std::string(TESTING_LIST_FILE)
                                                  : std::string(TRAINING_LIST_FILE);
        write_to_file(latestImg_
                      , std::to_string(static_cast<uint64_t>(getTimeSinceEpochMS()))
                      , IMG_DIR
                      , listFile
                      , 0);
    } else {
        currentDataSetSize_ = 0; //reset, mirroring the prediction branch
        train(PROTO_FILE_SOLV);
    }
}
/// Registers the signals this component can raise.
/// train() raises "StackedDAETrainingComplete" after a successful optimization.
void AI::set_events(){
addSignal("StackedDAETrainingComplete");
}
void AI::Run() {
while (Isis::isRunning) {
try {
pull_img_queue();
} catch (std::exception& e) {
log()->error("Exception in Run: %s",e.what());
}
}
}