-
Notifications
You must be signed in to change notification settings - Fork 49
/
Copy pathtop_k.py
366 lines (295 loc) · 13.3 KB
/
top_k.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
"""
Implements the SAE training scheme from https://arxiv.org/abs/2406.04093.
Significant portions of this code have been copied from https://github.com/EleutherAI/sae/blob/main/sae
"""
import einops
import torch as t
import torch.nn as nn
from collections import namedtuple
from typing import Optional
from ..config import DEBUG
from ..dictionary import Dictionary
from ..trainers.trainer import (
SAETrainer,
get_lr_schedule,
set_decoder_norm_to_unit_norm,
remove_gradient_parallel_to_decoder_directions,
)
@t.no_grad()
def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5):
    """Compute the geometric median of `points`. Used for initializing decoder bias.

    Uses Weiszfeld-style iteratively reweighted least squares, starting from the
    mean of the points.

    Args:
        points: tensor of shape (N, D), one point per row.
        max_iter: maximum number of reweighting iterations.
        tol: stop early once the estimate moves by less than this (L2 norm).

    Returns:
        A tensor of shape (D,) holding the estimated geometric median.
    """
    # Initialize our guess as the mean of the points
    guess = points.mean(dim=0)

    # Lower bound on point-to-guess distances. Without it, a guess that lands
    # exactly on a data point (e.g. a single point, or all points identical)
    # produces 1/0 = inf weights, and inf/inf = NaN after normalization.
    min_dist = t.finfo(guess.dtype).eps

    for _ in range(max_iter):
        prev = guess

        # Inverse-distance weights for iteratively reweighted least squares
        dists = t.norm(points - guess, dim=1).clamp_min(min_dist)
        weights = 1 / dists

        # Normalize the weights so they sum to 1
        weights /= weights.sum()

        # Compute the new geometric median estimate
        guess = (weights.unsqueeze(1) * points).sum(dim=0)

        # Early stopping condition
        if t.norm(guess - prev) < tol:
            break

    return guess
class AutoEncoderTopK(Dictionary, nn.Module):
    """
    The top-k autoencoder architecture and initialization used in https://arxiv.org/abs/2406.04093

    NOTE: (From Adam Karvonen) There is an unmaintained implementation using Triton kernels in the
    topk-triton-implementation branch. We abandoned it as we didn't notice a significant speedup and
    it added complications, which are noted in the AutoEncoderTopK class docstring in that branch.

    With some additional effort, you can train a Top-K SAE with the Triton kernels and modify the
    state dict for compatibility with this class. Notably, the Triton kernels currently require the
    decoder to be stored in nn.Parameter, not nn.Linear, and the decoder weights must also be stored
    in the same shape as the encoder.
    """

    def __init__(self, activation_dim: int, dict_size: int, k: int):
        """Build a top-k SAE.

        Args:
            activation_dim: width of the input activations.
            dict_size: number of dictionary features.
            k: number of features kept active per input by `encode`.
        """
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size

        assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer"
        # Registered as buffers so they are stored in (and restored from) the state dict.
        self.register_buffer("k", t.tensor(k, dtype=t.int))
        # -1 is a sentinel meaning "no inference threshold has been estimated yet".
        self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32))

        self.decoder = nn.Linear(dict_size, activation_dim, bias=False)
        self.decoder.weight.data = set_decoder_norm_to_unit_norm(
            self.decoder.weight, activation_dim, dict_size
        )

        # The encoder starts as the transpose of the normalized decoder, with zero bias.
        self.encoder = nn.Linear(activation_dim, dict_size)
        self.encoder.weight.data = self.decoder.weight.T.clone()
        self.encoder.bias.data.zero_()

        self.b_dec = nn.Parameter(t.zeros(activation_dim))

    def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False):
        """Encode activations into sparse feature activations.

        Args:
            x: input activations whose last dimension is `activation_dim`.
            return_topk: also return the top-k values, their indices and the dense
                post-ReLU activations.
            use_threshold: instead of an exact top-k, keep every post-ReLU activation
                strictly greater than `self.threshold`.

        Returns:
            `encoded_acts_BF`, or the tuple
            `(encoded_acts_BF, top_acts_BK, top_indices_BK, post_relu_feat_acts_BF)`
            when `return_topk` is True.
        """
        post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec))

        if use_threshold:
            encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold)
            if return_topk:
                # Top-k stats are still taken from the un-thresholded activations.
                post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)
                return encoded_acts_BF, post_topk.values, post_topk.indices, post_relu_feat_acts_BF
            return encoded_acts_BF

        post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)

        # We can't split immediately due to nnsight
        tops_acts_BK = post_topk.values
        top_indices_BK = post_topk.indices

        # Scatter the top-k values into an otherwise-zero dense tensor.
        buffer_BF = t.zeros_like(post_relu_feat_acts_BF)
        encoded_acts_BF = buffer_BF.scatter_(dim=-1, index=top_indices_BK, src=tops_acts_BK)

        if return_topk:
            return encoded_acts_BF, tops_acts_BK, top_indices_BK, post_relu_feat_acts_BF
        return encoded_acts_BF

    def decode(self, x: t.Tensor) -> t.Tensor:
        """Map feature activations back to activation space (decoder plus `b_dec`)."""
        return self.decoder(x) + self.b_dec

    def forward(self, x: t.Tensor, output_features: bool = False):
        """Reconstruct `x` via top-k encoding; optionally also return the features."""
        encoded_acts_BF = self.encode(x)
        x_hat_BD = self.decode(encoded_acts_BF)
        if not output_features:
            return x_hat_BD
        return x_hat_BD, encoded_acts_BF

    def scale_biases(self, scale: float):
        """Multiply the encoder bias, `b_dec` and (if already set) the threshold by `scale`."""
        self.encoder.bias.data *= scale
        self.b_dec.data *= scale
        # A negative threshold is the "not yet estimated" sentinel; leave it alone.
        if self.threshold >= 0:
            self.threshold *= scale

    @staticmethod
    def from_pretrained(path, k: Optional[int] = None, device=None):
        """
        Load a pretrained autoencoder from a file.

        Args:
            path: path to a state dict saved with `torch.save`.
            k: expected k; read from the state dict when None.
            device: optional device to load tensors onto and move the model to.

        Raises:
            ValueError: if `k` is given and disagrees with the stored `k`.

        NOTE(review): `t.load` unpickles the file; only load checkpoints from trusted sources.
        """
        # map_location=None keeps torch's default behaviour; an explicit device lets a
        # CUDA-saved checkpoint be loaded on a machine without that device.
        state_dict = t.load(path, map_location=device)
        dict_size, activation_dim = state_dict["encoder.weight"].shape
        if k is None:
            k = state_dict["k"].item()
        elif "k" in state_dict and k != state_dict["k"].item():
            raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']")

        autoencoder = AutoEncoderTopK(activation_dim, dict_size, k)
        autoencoder.load_state_dict(state_dict)
        if device is not None:
            autoencoder.to(device)
        return autoencoder
class TopKTrainer(SAETrainer):
    """
    Top-K SAE training scheme.

    Trains `dict_class` (AutoEncoderTopK by default) on an L2 reconstruction loss
    plus an auxiliary loss on dead features weighted by `auxk_alpha`.

    Each step uses Adam with a `get_lr_schedule` LR schedule. Gradients are clipped
    to norm 1.0. The decoder is re-normalized with `set_decoder_norm_to_unit_norm`
    after every step.

    After `threshold_start_step`, the trainer also keeps an EMA estimate in
    `self.ae.threshold`. The EMA tracks the per-row smallest positive top-k
    activation, averaged over the batch.
    """

    def __init__(
        self,
        steps: int,  # total number of steps to train for
        activation_dim: int,
        dict_size: int,
        k: int,
        layer: int,
        lm_name: str,
        dict_class: type = AutoEncoderTopK,
        lr: Optional[float] = None,
        auxk_alpha: float = 1 / 32,  # see Appendix A.2
        warmup_steps: int = 1000,
        decay_start: Optional[int] = None,  # when does the lr decay start
        threshold_beta: float = 0.999,
        threshold_start_step: int = 1000,
        seed: Optional[int] = None,
        device: Optional[str] = None,
        wandb_name: str = "AutoEncoderTopK",
        submodule_name: Optional[str] = None,
    ):
        super().__init__(seed)

        assert layer is not None and lm_name is not None
        self.layer = layer
        self.lm_name = lm_name
        self.submodule_name = submodule_name

        self.wandb_name = wandb_name
        self.steps = steps
        self.decay_start = decay_start
        self.warmup_steps = warmup_steps
        self.k = k
        self.threshold_beta = threshold_beta
        self.threshold_start_step = threshold_start_step

        if seed is not None:
            t.manual_seed(seed)
            t.cuda.manual_seed_all(seed)

        # Initialise autoencoder
        self.ae = dict_class(activation_dim, dict_size, k)
        if device is None:
            self.device = "cuda" if t.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.ae.to(self.device)

        if lr is not None:
            self.lr = lr
        else:
            # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper
            scale = dict_size / (2**14)
            self.lr = 2e-4 / scale**0.5

        self.auxk_alpha = auxk_alpha
        # A feature counts as dead after this many consecutive tokens without firing.
        self.dead_feature_threshold = 10_000_000
        self.top_k_aux = activation_dim // 2  # Heuristic from B.1 of the paper
        # Must live on the model's device: it is indexed with (and compared against)
        # tensors produced by the autoencoder. `device` may be None here.
        self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=self.device)

        # Attributes reported for logging; -1 means "not computed yet".
        self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"]
        self.effective_l0 = -1
        self.dead_features = -1
        self.pre_norm_auxk_loss = -1

        # Optimizer and scheduler
        self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999))
        lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start)
        self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)

    def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor):
        """Auxiliary loss: reconstruct the residual from the top dead features.

        Args:
            residual_BD: reconstruction error; `loss` passes `e.detach()`.
            post_relu_acts_BF: dense post-ReLU encoder activations.

        Returns:
            The auxiliary L2 loss divided by the residual's squared deviation from
            its batch mean (NaN mapped to 0), or a zero tensor if no feature is dead.
            Also updates `self.dead_features` and `self.pre_norm_auxk_loss`.
        """
        dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold
        self.dead_features = int(dead_features.sum())

        if self.dead_features > 0:
            k_aux = min(self.top_k_aux, self.dead_features)

            # Mask out live features so topk only selects dead ones.
            auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf)

            # Top-k dead latents
            auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False)

            auxk_buffer_BF = t.zeros_like(post_relu_acts_BF)
            auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts)

            # Note: decoder(), not decode(), as we don't want to apply the bias
            x_reconstruct_aux = self.ae.decoder(auxk_acts_BF)
            l2_loss_aux = (
                (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean()
            )

            # Stored for logging only; detach so the autograd graph isn't kept alive.
            self.pre_norm_auxk_loss = l2_loss_aux.detach()

            # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614
            residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape)
            loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean()
            normalized_auxk_loss = l2_loss_aux / loss_denom
            return normalized_auxk_loss.nan_to_num(0.0)

        self.pre_norm_auxk_loss = -1
        return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device)

    def update_threshold(self, top_acts_BK: t.Tensor):
        """Update the EMA inference threshold from a batch of top-k activations.

        For each row, takes the smallest strictly positive top-k activation and
        averages over rows. The first estimate initializes `self.ae.threshold`;
        later ones are blended in with weight `1 - threshold_beta`.
        """
        device_type = "cuda" if top_acts_BK.is_cuda else "cpu"
        with t.autocast(device_type=device_type, enabled=False), t.no_grad():
            active = top_acts_BK.clone().detach()
            assert len(active.shape) == 2
            B = active.shape[0]

            # Exclude non-positive values from the per-row minimum.
            active[active <= 0] = float("inf")
            min_activations = active.min(dim=1).values.to(dtype=t.float32)
            assert min_activations.shape == (B,)

            # A row with no positive activation has min == inf. Averaging it in would
            # make the threshold inf, and the EMA could then never recover.
            finite = t.isfinite(min_activations)
            if not finite.any():
                return
            min_activation = min_activations[finite].mean()

            if self.ae.threshold < 0:
                self.ae.threshold = min_activation
            else:
                self.ae.threshold = (self.threshold_beta * self.ae.threshold) + (
                    (1 - self.threshold_beta) * min_activation
                )

    def loss(self, x, step=None, logging=False):
        """Compute the training loss for a batch.

        Side effects: may update the threshold EMA, and updates `effective_l0`, the
        per-feature "tokens since fired" counters and the auxiliary-loss stats.

        Args:
            x: batch of activations; `x.size(0)` is counted as the number of tokens.
            step: current training step. The threshold is only updated when `step`
                is given and exceeds `threshold_start_step`.
            logging: if True, return a LossLog namedtuple instead of the loss.

        Returns:
            The scalar loss tensor, or `LossLog(x, x_hat, f, losses)` where `losses`
            maps "l2_loss", "auxk_loss" and "loss" to Python floats.
        """
        # Run the SAE
        f, top_acts_BK, top_indices_BK, post_relu_acts_BF = self.ae.encode(
            x, return_topk=True, use_threshold=False
        )

        # `step` defaults to None; comparing None > int would raise TypeError.
        if step is not None and step > self.threshold_start_step:
            self.update_threshold(top_acts_BK)

        x_hat = self.ae.decode(f)

        # Measure goodness of reconstruction
        e = x - x_hat

        # Update the effective L0 (again, should just be K)
        self.effective_l0 = top_acts_BK.size(1)

        # Update "number of tokens since fired" for each feature
        num_tokens_in_step = x.size(0)
        did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool)
        did_fire[top_indices_BK.flatten()] = True
        self.num_tokens_since_fired += num_tokens_in_step
        self.num_tokens_since_fired[did_fire] = 0

        l2_loss = e.pow(2).sum(dim=-1).mean()
        if self.auxk_alpha > 0:
            auxk_loss = self.get_auxiliary_loss(e.detach(), post_relu_acts_BF)
        else:
            # A tensor rather than the int 0, so `auxk_loss.item()` works when logging.
            auxk_loss = t.zeros((), dtype=l2_loss.dtype, device=l2_loss.device)
        loss = l2_loss + self.auxk_alpha * auxk_loss

        if not logging:
            return loss
        return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])(
            x,
            x_hat,
            f,
            {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()},
        )

    def update(self, step, x):
        """Run one optimization step on batch `x` and return the loss as a float."""
        # Move the batch first so the geometric median (and hence b_dec) is computed
        # on, and stays on, the model's device.
        x = x.to(self.device)

        # Initialise the decoder bias
        if step == 0:
            median = geometric_median(x)
            self.ae.b_dec.data = median.to(
                device=self.ae.b_dec.device, dtype=self.ae.b_dec.dtype
            )

        # compute the loss
        loss = self.loss(x, step=step)
        loss.backward()

        # clip grad norm and remove grads parallel to decoder directions
        self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions(
            self.ae.decoder.weight,
            self.ae.decoder.weight.grad,
            self.ae.activation_dim,
            self.ae.dict_size,
        )
        t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0)

        # do a training step
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.scheduler.step()

        # Make sure the decoder is still unit-norm
        self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm(
            self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size
        )

        return loss.item()

    @property
    def config(self):
        """Hyperparameters and metadata describing this trainer run."""
        return {
            "trainer_class": "TopKTrainer",
            "dict_class": "AutoEncoderTopK",
            "lr": self.lr,
            "steps": self.steps,
            "auxk_alpha": self.auxk_alpha,
            "warmup_steps": self.warmup_steps,
            "decay_start": self.decay_start,
            "threshold_beta": self.threshold_beta,
            "threshold_start_step": self.threshold_start_step,
            "seed": self.seed,
            "activation_dim": self.ae.activation_dim,
            "dict_size": self.ae.dict_size,
            "k": self.ae.k.item(),
            "device": self.device,
            "layer": self.layer,
            "lm_name": self.lm_name,
            "wandb_name": self.wandb_name,
            "submodule_name": self.submodule_name,
        }