-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathA3CModel.h
153 lines (119 loc) · 4.61 KB
/
A3CModel.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#ifndef A3C_MODEL_H
#define A3C_MODEL_H
#include "RLModel.h"
#include <memory>
#include <vector>
#include <deque>
#include <mutex>
#include <random>
#include <cmath>
#include <thread>
#include <atomic>
#include <fstream>
#include "hyperparameter_tuner.h"
#include "actor_critic.h"
#include "Globals.h"
// Forward-deklaraatiot
class A3CModel;
class A3CNetwork;
class A3CWorker;
// A3C-työntekijäluokka, joka suorittaa oppimisen yhdessä säikeessä
class A3CWorker {
private:
int workerId;
ActorCritic localNetwork;
A3CNetwork& globalNetwork;
// Gradienttien keräämiseen
std::vector<float> states;
std::vector<float> actions;
std::vector<float> rewards;
std::vector<float> values;
// Hyperparametrit
float gamma;
public:
A3CWorker(int id, A3CNetwork& global, float g = 0.99f, int updateFreq = 20);
// Suorita eteenpäinkulku paikallisella verkolla
void forward(float state[STATE_SIZE], float action[ACTION_SIZE], float& value);
// Kerää kokemus gradienttien laskemista varten
void collectExperience(float state[STATE_SIZE], float action[ACTION_SIZE],
float reward, float value);
// Laske gradientit ja päivitä globaali verkko
void updateGlobalNetwork(float nextStateValue, bool isTerminal);
// Tarkista pitäisikö päivittää globaali verkko
bool shouldUpdate() const;
// Palauta työntekijän ID
int getId() const;
// Palauta paikallinen verkko
ActorCritic& getLocalNetwork() const;
// Aseta päivitystaajuus
void setUpdateFrequency(int freq);
// Palauta rewards-vektorin koko
size_t getRewardsSize() const;
// Aseta hyperparametrit
void setHyperParams(const HyperParameters& params);
};
// A3C (Asynchronous Advantage Actor-Critic) verkko
class A3CNetwork {
private:
// Globaali verkko, jota kaikki säikeet päivittävät
ActorCritic globalNetwork;
// Mutex globaalin verkon päivityksille
std::mutex globalNetworkMutex;
// Hyperparametrit
float learningRate;
float gamma;
float tau; // Target-verkkojen päivitysnopeus
// Tilastot
std::atomic<int> totalUpdates{0};
std::atomic<float> averageReward{0.0f};
public:
A3CNetwork(float lr = 0.0001f, float g = 0.99f, float t = 0.001f);
// Alusta työntekijäverkko globaalin verkon painoilla
void initializeWorker(ActorCritic& worker);
// Kopioi verkon painot lähteestä kohteeseen
void copyNetworkWeights(const ActorCritic& source, ActorCritic& target);
// Päivitä globaali verkko työntekijän gradienteilla
void updateGlobalNetwork(ActorCritic& worker);
// Päivitä target-verkot soft-update-menetelmällä
void updateTargetNetworks();
// Hae globaalin verkon painot työntekijälle
void pullGlobalNetworkWeights(ActorCritic& worker);
// Tallenna globaalin verkon painot tiedostoon
void saveGlobalWeights(const std::string& filename);
// Lataa globaalin verkon painot tiedostosta
void loadGlobalWeights(const std::string& filename);
// Päivitä oppimisnopeutta
void updateLearningRate(float newLR);
// Päivitä keskimääräistä palkkiota
void updateAverageReward(float reward);
// Getterit
float getAverageReward() const;
int getTotalUpdates() const;
float getLearningRate() const;
};
// A3C-mallin toteutus RLModel-rajapinnan kautta
class A3CModel : public RLModel {
private:
A3CWorker& worker;
public:
A3CModel(A3CWorker& w);
void forward(float state[/*STATE_SIZE*/], float action[/*ACTION_SIZE*/], float& value) override;
void collectExperience(float state[/*STATE_SIZE*/], float action[/*ACTION_SIZE*/],
float reward, float value) override;
void update(float nextStateValue, bool isTerminal) override;
bool shouldUpdate() const override;
size_t getRewardsSize() const override;
std::string getName() const override;
float getLastAdvantage() const override;
float getLastWeightUpdate() const override;
float getTDError() const override;
int getUpdateCounter() const override;
void setUpdateFrequency(int freq) override;
const HyperParameters& getHyperParams() const override;
bool saveModel(const std::string& filename) const override;
bool loadModel(const std::string& filename) override;
void setHyperParams(const HyperParameters& params) override;
// Staattinen metodi, joka luo ja lataa A3C-mallin
static std::unique_ptr<RLModel> createAndLoad(A3CWorker& worker, const std::string& filename);
};
#endif // A3C_MODEL_H