# MCTS.py
from node import AlphaZeroNode, Node
from games import Game
import numpy as np
import torch
import torch.nn as nn
from global_vars import device


class MCTS:
    def __init__(self, game: Game, args: dict):
        """
        A standard (non-AlphaZero) Monte Carlo Tree Search implementation.

        :param game: An instance of the Game class providing game-specific logic.
        :param args: A dictionary of configuration parameters (e.g., num_searches).
        """
        self.game = game
        self.args = args

    def search(self, state):
        """
        Perform MCTS searches starting from the given state.

        :param state: The initial state from which to begin the MCTS.
        :return: A probability distribution (action_probs) over all possible actions.
        """
        # 1. Define the root node for the current search.
        root = Node(self.game, self.args, state)

        # 2. Run several search iterations (num_searches controls how many times we traverse the tree).
        for search in range(self.args["num_searches"]):
            node = root

            # --- SELECTION phase ---
            # Move down the tree along the most promising path (based on UCB or similar).
            while node.is_fully_expanded():
                node = node.select()

            # Check whether the state is terminal (win/loss/draw) and retrieve its value.
            value, terminated = self.game.get_value_and_terminated(node.state, node.action_taken)
            # Adjust the value to the opponent's perspective if needed.
            value = self.game.get_opponent_value(value)

            # --- EXPANSION & SIMULATION phase ---
            if not terminated:
                # Expand the node by adding one or more children (possible next states).
                node = node.expand()
                # Simulate a rollout from the newly expanded node to estimate the outcome.
                value = node.simulate()

            # --- BACKPROPAGATION phase ---
            # Propagate the simulation result (value) back up the tree, updating parent nodes.
            node.backpropagate(value)

        # Build a probability distribution (action_probs) over actions based on visit counts.
        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        # Normalize the probabilities across all actions.
        action_probs /= np.sum(action_probs)
        return action_probs
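
# Usage sketch (illustrative, not part of the original module). It assumes a
# concrete Game implementation such as a TicTacToe class in games.py exposing
# get_initial_state() and get_next_state(), plus a "C" (UCB exploration
# constant) entry in args; those names are assumptions of this sketch, not
# definitions from this file.
#
#   game = TicTacToe()
#   args = {"num_searches": 1000, "C": 1.41}
#   mcts = MCTS(game, args)
#   state = game.get_initial_state()
#   action_probs = mcts.search(state)
#   action = int(np.argmax(action_probs))          # greedy move from visit counts
#   state = game.get_next_state(state, action, 1)  # assumed signature: (state, action, player)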


class AlphaZeroMCTS:
    def __init__(self, game: Game, model: nn.Module, args: dict):
        """
        An MCTS that uses a neural network (AlphaZero style) to guide tree exploration.

        :param game: An instance of the Game class providing game-specific logic.
        :param model: A PyTorch neural network model that outputs policy and value.
        :param args: A dictionary of configuration parameters (e.g., num_searches, epsilon, alpha).
        """
        self.game = game
        self.model = model
        self.args = args
        self.model.to(device)

    @torch.no_grad()
    def search(self, state):
        """
        Execute MCTS searches guided by the neural network.

        :param state: The initial state from which to begin the AlphaZero MCTS.
        :return: A probability distribution (action_probs) over all possible actions.
        """
        # 1. Create the root node and initialize its visit count.
        root = AlphaZeroNode(self.game, self.args, state, visit_count=1)

        # Get the initial policy (neural network prediction for the root state).
        policy, _ = self.model(self.state_to_tensor(state))
        # Convert model outputs (logits) to probabilities.
        policy = self.policy_to_prob(policy)
        # Apply Dirichlet noise at the root to encourage exploration, and mask invalid moves.
        policy = self.mask_invalid_moves(policy, state, noise=True)
        # Expand the root node with the policy distribution over its children.
        root.expand(policy)

        # 2. Run multiple searches (each search is a traversal from root to leaf).
        for search in range(self.args['num_searches']):
            node = root

            # --- SELECTION phase ---
            while node.is_fully_expanded():
                node = node.select()

            # Compute the value of the state and check whether it is terminal.
            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            # Flip the value perspective if needed (depending on whose turn it is).
            value = self.game.get_opponent_value(value)

            # --- EXPANSION phase ---
            if not is_terminal:
                tensor_state = self.state_to_tensor(node.state)
                policy, value = self.model(tensor_state)
                # Convert model logits to a probability distribution over next moves.
                policy = self.policy_to_prob(policy)
                # Mask out invalid actions.
                policy = self.mask_invalid_moves(policy, node.state)
                # Convert the NN value output to a scalar.
                value = self.value_to_scalar(value)
                # Expand the node with valid moves from the policy distribution.
                node.expand(policy)

            # --- BACKPROPAGATION phase ---
            node.backpropagate(value)

        # Convert the root children's visit counts into a normalized probability distribution.
        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        action_probs /= np.sum(action_probs)
        return action_probs

    def state_to_tensor(self, state):
        """
        Encodes the state into a suitable tensor format for the neural network.

        :param state: The current game state.
        :return: A torch.Tensor containing the encoded state on the correct device (CPU/GPU).
        """
        encoded_state = self.game.encode_state(state)
        tensor_state = torch.tensor(encoded_state).unsqueeze(0)
        return tensor_state.to(device)

    def policy_to_prob(self, policy):
        """
        Converts raw policy logits from the network into a probability distribution.

        :param policy: Raw network output (logits).
        :return: A numpy array of normalized policy values.
        """
        policy = torch.softmax(policy, dim=1).squeeze(0).cpu().detach().numpy()
        return policy

    def value_to_scalar(self, value):
        """
        Converts the network's value output (tensor) to a Python scalar.

        :param value: The network's value output.
        :return: A floating-point scalar value.
        """
        return value.item()

    def mask_invalid_moves(self, policy, state, noise=False):
        """
        Masks out invalid moves from the policy distribution, optionally adding Dirichlet noise.

        :param policy: The policy distribution over actions from the neural network.
        :param state: The current game state.
        :param noise: Whether to add Dirichlet noise to encourage exploration.
        :return: A new policy array with invalid moves zeroed out and renormalized.
        """
        valid_moves = self.game.get_valid_moves(state)
        # If noise is enabled, blend the original policy distribution with a Dirichlet sample.
        if noise:
            epsilon = self.args.get("epsilon", 0.001)
            alpha = self.args.get("alpha", 0.3)
            policy = (1 - epsilon) * policy + epsilon * np.random.dirichlet(
                alpha=alpha * np.ones(self.game.action_size)
            )
        # Zero out invalid moves and re-normalize the distribution.
        policy *= valid_moves
        policy /= np.sum(policy)
        return policy
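
# Usage sketch (illustrative, not part of the original module). It assumes a
# policy/value network defined elsewhere in the project (any nn.Module whose
# forward pass returns (policy_logits, value)) and a TicTacToe game with
# get_initial_state(); those names, the "C" key, and the temperature value are
# assumptions of this sketch. Sampling from the visit-count distribution with a
# temperature is the usual self-play choice; argmax gives greedy play instead.
#
#   game = TicTacToe()
#   model = MyPolicyValueNet(game)   # hypothetical nn.Module returning (policy_logits, value)
#   args = {"num_searches": 600, "C": 2, "epsilon": 0.25, "alpha": 0.3}
#   mcts = AlphaZeroMCTS(game, model, args)
#   state = game.get_initial_state()
#   action_probs = mcts.search(state)
#   temperature = 1.0
#   temp_probs = action_probs ** (1 / temperature)
#   temp_probs /= np.sum(temp_probs)
#   action = np.random.choice(game.action_size, p=temp_probs)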


class AlphaZeroParallelMCTS:
    def __init__(self, game: Game, model: nn.Module, args: dict):
        """
        A parallelized AlphaZero MCTS that handles multiple states in a batch.

        :param game: An instance of the Game class for game-specific logic.
        :param model: A PyTorch neural network model that outputs policy and value.
        :param args: A dictionary of configuration parameters (e.g., num_searches, epsilon, alpha).
        """
        self.game = game
        self.model = model
        self.args = args
        self.model.to(device)

    @torch.no_grad()
    def search(self, states, spgs):
        """
        Perform parallel MCTS searches for a batch of states and their
        corresponding SPG objects (spgs), which hold each game's MCTS root
        and the node pending expansion.

        :param states: A list or batch of states to run MCTS on.
        :param spgs: A list of objects that store MCTS nodes and results for each state.
        """
        # 1. Get the initial policy for all states in a single forward pass.
        policy, _ = self.model(self.state_to_tensor(states))
        policy = self.policy_to_prob(policy)
        # Add Dirichlet noise once for each state in the batch to encourage exploration.
        policy = self.add_noise(policy)

        # 2. For each state in the batch, create and expand its root node.
        for i, spg in enumerate(spgs):
            spg_policy = policy[i]
            spg_policy = self.mask_invalid_moves(spg_policy, states[i])
            spg.root = AlphaZeroNode(self.game, self.args, states[i], visit_count=1)
            spg.root.expand(spg_policy)

        # 3. Perform the configured number of MCTS searches.
        for search in range(self.args['num_searches']):
            # For each state, select a path down the tree.
            for i, spg in enumerate(spgs):
                spg.node = None
                node = spg.root

                # --- SELECTION: follow the best child until a leaf or unexpanded node is found ---
                while node.is_fully_expanded():
                    node = node.select()

                # Check whether the state is terminal and retrieve its value.
                value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
                value = self.game.get_opponent_value(value)

                # If the node's state is terminal, backpropagate the value; otherwise, mark it for expansion.
                if is_terminal:
                    node.backpropagate(value)
                else:
                    spg.node = node  # This node will be expanded after the NN forward pass.

            # 4. Gather all the nodes that need expansion into a batch for a single forward pass.
            expandable_spgs = [mapping_idx for mapping_idx, spg in enumerate(spgs) if spg.node is not None]

            # If there are nodes to expand, evaluate them all in a single batched forward pass.
            if len(expandable_spgs) > 0:
                # Collect the states that need expansion into a batch.
                expandable_states = np.stack([spgs[mapping_idx].node.state for mapping_idx in expandable_spgs])
                tensor_states = self.state_to_tensor(expandable_states)
                policy, value = self.model(tensor_states)
                # Convert model outputs to probabilities; values are converted to scalars per node below.
                policy = self.policy_to_prob(policy)

                # 5. Expand each node in the parallel batch and backpropagate the evaluated value.
                for i, mapping_idx in enumerate(expandable_spgs):
                    node = spgs[mapping_idx].node
                    spg_policy = policy[i]
                    # Convert this node's NN value output to a scalar (as in AlphaZeroMCTS).
                    spg_value = self.value_to_scalar(value[i])
                    # Mask invalid moves from the policy distribution.
                    spg_policy = self.mask_invalid_moves(spg_policy, node.state)
                    # Expand the node with the valid policy distribution.
                    node.expand(spg_policy)
                    # Backpropagate the value up the tree.
                    node.backpropagate(spg_value)

    def state_to_tensor(self, state):
        """
        Encodes one or multiple states into a tensor for batch evaluation.

        :param state: A single state or a batch of states.
        :return: A torch.Tensor ready to be processed by the model.
        """
        encoded_state = self.game.encode_state(state)
        tensor_state = torch.tensor(encoded_state)
        return tensor_state.to(device)

    def policy_to_prob(self, policy):
        """
        Transforms raw policy logits into a probability distribution for each state in the batch.

        :param policy: Raw network output (logits) with shape [batch_size, action_size].
        :return: A numpy array of shape [batch_size, action_size], normalized per row.
        """
        policy = torch.softmax(policy, dim=1).cpu().detach().numpy()
        return policy

    def value_to_scalar(self, value):
        """
        Converts the network's value output (tensor) to a Python scalar.

        :param value: The network's value output.
        :return: A floating-point scalar value.
        """
        return value.item()

    def add_noise(self, policy):
        """
        Adds Dirichlet noise to the policy distributions to encourage exploration in a batch context.

        :param policy: A [batch_size, action_size] array of policy distributions.
        :return: A new policy array with noise applied.
        """
        epsilon = self.args.get("epsilon", 0.001)
        alpha = self.args.get("alpha", 0.3)
        # Apply noise to each row of the batch distribution.
        policy = (1 - epsilon) * policy + epsilon * np.random.dirichlet(
            alpha=alpha * np.ones(self.game.action_size),
            size=policy.shape[0]
        )
        return policy

    def mask_invalid_moves(self, policy, state):
        """
        Zeroes out invalid moves and renormalizes the distribution for the given state.

        :param policy: The policy distribution for a specific state (single row).
        :param state: The current game state.
        :return: The updated policy distribution with invalid moves removed.
        """
        valid_moves = self.game.get_valid_moves(state)
        policy *= valid_moves
        policy /= np.sum(policy)
        return policy
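

# Minimal sketch (not part of the original module) of the SPG container that
# AlphaZeroParallelMCTS.search expects. search() only ever assigns spg.root and
# spg.node, so any object exposing those attributes works; the state field and
# any self-play bookkeeping (e.g., a replay memory) are assumptions about the
# project's real SPG class.
class ExampleSPG:
    def __init__(self, state):
        self.state = state  # current board state for this parallel game (assumed field)
        self.root = None    # set by AlphaZeroParallelMCTS.search()
        self.node = None    # leaf node marked for batched expansion, or None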