forked from rohanchandra30/TrackNPred
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path prediction_cmd.py
33 lines (27 loc) · 844 Bytes
/
prediction_cmd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from model.model import TnpModel
## Settings for a prediction-only run of TnpModel.
#
# NOTE(review): the GUI front-end this script was derived from also filled in
# path/detection keys ("frames", "detection", "detConf", "NMS", "display")
# from its widgets. They are omitted here; presumably evaluate() does not
# read them -- confirm against model.model.TnpModel before relying on that.


def build_args():
    """Return a fresh settings dict to pass to TnpModel.

    A new dict is built on every call, so a caller can tweak one copy
    (e.g. set ``"cuda"`` to False) without affecting any other caller.
    """
    return {
        # prediction settings
        "predAlgo": "Traphic",
        "pretrainEpochs": 6,
        "trainEpochs": 10,
        "batch_size": 64,
        "dropout": 0.5,
        "optim": "Adam",
        "lr": 0.0001,
        "cuda": True,
        "maneuvers": False,
        # Relative paths; presumably resolved against the current working
        # directory by TnpModel -- run from the repository root.
        "modelLoc": "resources/trained_models/Traphic_model.tar",
        "pretrain_loss": "",
        "train_loss": "MSE",
        "dir": "resources/data/TRAF",
    }


# Kept at module level so existing ``from prediction_cmd import args`` users
# keep working; building the dict has no side effects.
args = build_args()


def main(train=False):
    """Construct a TnpModel and evaluate it with the build_args() settings.

    Args:
        train: when True, call ``model.train`` with the same settings before
            evaluating. This replaces the old commented-out train line.
            Defaults to False, matching the script's previous behavior.

    Returns:
        Whatever ``TnpModel.evaluate`` returns.
    """
    settings = build_args()
    model = TnpModel()
    if train:
        model.train(settings)
    return model.evaluate(settings)


# Guard the entry point so merely importing this module no longer starts an
# evaluation run; ``python prediction_cmd.py`` behaves as before.
if __name__ == "__main__":
    main()