-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainCNN2dManualParams.m
56 lines (43 loc) · 1.62 KB
/
trainCNN2dManualParams.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
function trainCNN2dManualParams(classifierType,params,opts)
% trainCNN2dManualParams train a 2D CNN with manually specified parameters
%
% trainCNN2dManualParams(classifierType) trains and evaluates a 2D CNN.
% classifierType is a handle to the 2D CNN class, i.e. @CNN2D.
%
% trainCNN2dManualParams(classifierType,ClassifierParams=p) additionally
% forwards p to validationObjFcn as its Static argument. p holds the CNN's
% hyperparameters (default {}).
%
% Training data ("trainingData", "trainingImgLabels") and validation data
% ("validationData", "validationImgLabels") are loaded from the directories
% provided by beehiveDataSetup. The resulting objective is displayed and,
% together with the userdata returned by validationObjFcn, saved to
%   <trainingResultsDir>/default-params/<Classifier.Name>ManualParamTraining.mat
%
% Name-Value arguments:
%   ClassifierParams - Hyperparameters passed to validationObjFcn as
%                      Static (default {})
%   UseParallel      - Use the parallel computing toolbox; a parallel pool
%                      is started if none is running (default false)
%   UseGPU           - Use a GPU for training (default false)
% SPDX-License-Identifier: BSD-3-Clause
    arguments
        classifierType function_handle
        params.ClassifierParams = {}
        opts.UseParallel = false
        opts.UseGPU = false
    end

    % Start a pool up front (only if one is not already running) so that
    % pool start-up is not attributed to training.
    if opts.UseParallel
        if isempty(gcp('nocreate'))
            parpool();
        end
    end

    % Set up data paths. beehiveDataSetup is a script run in this workspace;
    % it is expected to define trainingDataDir, validationDataDir and
    % trainingResultsDir (NOTE(review): confirm against beehiveDataSetup).
    beehiveDataSetup;

    % Load into structs instead of creating variables implicitly in the
    % function workspace, and read only the variables used below.
    % fullfile works whether the directories are char vectors or strings
    % (char + filesep would be numeric addition, not concatenation).
    trainSet = load(fullfile(trainingDataDir,"trainingData"), ...
        "trainingData","trainingImgLabels");
    valSet = load(fullfile(validationDataDir,"validationData"), ...
        "validationData","validationImgLabels");

    % Train and evaluate the classifier
    [objective,~,userdata] = validationObjFcn(classifierType, ...
        trainSet.trainingData,trainSet.trainingImgLabels, ...
        valSet.validationData,valSet.validationImgLabels, ...
        UseParallel=opts.UseParallel,UseGPU=opts.UseGPU, ...
        Static=params.ClassifierParams);
    disp("objective = " + objective);

    % Save results so we can do parameter selection later
    filename = userdata.Classifier.Name + "ManualParamTraining.mat";
    resultsDir = fullfile(trainingResultsDir,"default-params");
    % isfolder (unlike exist) is false when a *file* with this name exists,
    % so the folder is still created and save does not fail.
    if ~isfolder(resultsDir)
        mkdir(resultsDir);
    end
    save(fullfile(resultsDir,filename),"objective","userdata",'-v7.3');
end