// simhashtrainer.hpp
// Forked from googleprojectzero/functionsimsearch.
#ifndef SIMHASHTRAINER_HPP
#define SIMHASHTRAINER_HPP
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include <spii/function.h>

#include "util.hpp"
// Loads the training inputs from the data directory (the first argument).
// Three files are expected to be present in that directory:
//
//   functions.txt - the concatenated output of the functionfingerprints tool
//                   run in verbose mode. Every line begins with
//                   [file_id]:[address] and is followed by the features that
//                   make up that function.
//   attract.txt   - pairs of [file_id]:[address] [file_id]:[address] naming
//                   functions that should be the same.
//   repulse.txt   - pairs of [file_id]:[address] [file_id]:[address] naming
//                   functions that should NOT be the same.
//
// NOTE(review): the four pointer parameters are non-const, so they are
// presumably outputs filled in by the loader, and the uint32_t pairs are
// presumably indices into *all_functions -- confirm against the
// implementation. The meaning of the bool result (presumably true on success)
// should be confirmed there as well.
bool LoadTrainingData(
    const std::string& directory,
    std::vector<FunctionFeatures>* all_functions,
    std::vector<FeatureHash>* all_features_vector,
    std::vector<std::pair<uint32_t, uint32_t>>* attractionset,
    std::vector<std::pair<uint32_t, uint32_t>>* repulsionset);
// Convenience wrapper: takes a data directory laid out as described for
// LoadTrainingData and writes a file full of weights to `weights_filename`.
//
// NOTE(review): the meaning of the bool result is not visible from this
// header (presumably true on success) -- confirm in the implementation.
bool TrainSimHashFromDataDirectory(
    const std::string& directory,
    const std::string& weights_filename);
// Trainer that produces a vector of weights from a collection of functions,
// their features, and two sets of function pairs (an "attraction" set and a
// "repulsion" set).
//
// NOTE(review): the class keeps raw const pointers to its inputs. They are
// presumably non-owning, in which case the pointed-to vectors must outlive
// the trainer -- confirm against the constructor definition.
class SimHashTrainer {
  // Shorthand for the pair-list type shared by the attraction and repulsion
  // sets. Purely an alias; the underlying type is unchanged.
  using PairList = std::vector<std::pair<uint32_t, uint32_t>>;

 public:
  // Builds a trainer over the given functions, features and pair sets.
  // Each parameter has a same-named private member below, which presumably
  // receives it (constructor body not visible here).
  SimHashTrainer(
      const std::vector<FunctionFeatures>* all_functions,
      const std::vector<FeatureHash>* all_features,
      const PairList* attractionset,
      const PairList* repulsionset);

  // Runs the training and reports the result through `weights`.
  // NOTE(review): `weights` is a non-const pointer, so it is presumably an
  // output parameter -- confirm whether its prior contents are read.
  void Train(std::vector<double>* weights);

 private:
  // Judging by the name, adds the loss term for a single function pair to
  // `function`. `attract` presumably selects between the attraction and the
  // repulsion form of the term, and `set_size` is presumably the size of the
  // set the pair was taken from -- confirm both in the implementation.
  void AddPairLossTerm(
      const std::pair<uint32_t, uint32_t>& pair,
      spii::Function* function,
      const std::vector<FunctionFeatures>* all_functions,
      const std::vector<FeatureHash>* all_features_vector,
      std::vector<std::vector<double>>* weights,
      uint32_t set_size,
      bool attract);

  // Inputs kept for the lifetime of the trainer (declaration order preserved).
  const std::vector<FunctionFeatures>* all_functions_;
  const std::vector<FeatureHash>* all_features_;
  const PairList* attractionset_;
  const PairList* repulsionset_;
};
#endif // SIMHASHTRAINER_HPP