Commit

Remove PCA reweighting
Niklas Böhm committed Dec 12, 2024
1 parent 8ca0742 commit f62b7b5
Showing 1 changed file with 0 additions and 108 deletions.
tsimcne/tsimcne.py: 0 additions, 108 deletions
@@ -44,7 +44,6 @@ def __init__(
weight_decay=5e-4,
momentum=0.9,
warmup_epochs=10,
dim_anneal_strategy="pca",
batches_per_epoch=None,
random_state=None,
save_intermediate_feat=False,
@@ -79,7 +78,6 @@ def __init__(
self.weight_decay = weight_decay
self.momentum = momentum
self.warmup_epochs = warmup_epochs
self.dim_anneal_strategy = dim_anneal_strategy
self.batches_per_epoch = batches_per_epoch
self.random_state = random_state
self.save_intermediate_feat = save_intermediate_feat
@@ -226,21 +224,6 @@ def _handle_parameters(self):
f"got {self.optimizer_name}."
)

# if self.lr_scheduler_name not in ["cos_annealing", "constant"]:
# raise ValueError(
# "Only 'cos_annealing' and 'constant' is supported as "
# f"learning rate scheduler, got {self.lr_scheduler_name}."
# )

if (
self.dim_anneal_strategy is not None
and self.dim_anneal_strategy != "pca"
):
raise ValueError(
"Expected None or 'pca' for dim_anneal strategy, got "
f"{self.dim_anneal_strategy}."
)

self.alphas = torch.sin(
torch.linspace(0, 1, self.n_epochs) * torch.pi / 2
)
@@ -329,100 +312,9 @@ def on_train_epoch_start(self):
else:
self.log("dof", self.cur_dof, prog_bar=False)

prev_dim_mask = self.dim_mask
self.dim_mask = self.dim_mask_schedule[self.current_epoch]
self.cur_dof = self.dofs[self.current_epoch]

_embeddings = self._embeddings = self.train_embeddings

# next_output_dim = self.out_dim + self.dim_mask.stop
# vv and next_output_dim < 10
do_pca = (
self.dim_anneal_strategy == "pca"
and prev_dim_mask != self.dim_mask
and not isinstance(self.model.projection_head, torch.nn.Identity)
)
if do_pca:
layer = self.model.projection_head.layers[2]
self.weights = weights = layer.weight[prev_dim_mask]
# bias = layer.bias[prev_dim_mask]

# weights.T.cpu().detach()
embs = torch.vstack(_embeddings).cpu().detach().float()
unused_w = (
layer.weight[self.dim_mask.stop :]
.cpu()
.detach()
.float()
.T.numpy()
)
# unused_b = (
# layer.bias[self.dim_mask.stop :].cpu().detach().float().numpy()
# )

from sklearn.decomposition import PCA

self.pca = pca = PCA(
min(self.out_dim, self.out_dim + self.dim_mask.stop)
).fit(embs)
# make_pipeline(
# StandardScaler(with_std=False),
# TruncatedSVD(
# min(self.out_dim, self.out_dim + self.dim_mask.stop)
# ),
# )

# _pca_w = pca[1].transform(weights.T.cpu().detach().float().numpy())
_pca_w = (
weights.T.cpu().detach().float().numpy() @ pca.components_.T
)
# _e = embs.numpy()
# _w = weights.detach().cpu().numpy()
rotated_w = _pca_w # * _w.std() # + _w.mean()
transformed_weight = np.hstack((rotated_w, unused_w))

odict = self.optimizers().optimizer.state
mdict = odict[layer.weight]
# mdict.clear()
# print(len(odict.keys()), odict[weights])
## attempt at pca-transforming the momentum
mbuf = mdict["momentum_buffer"]
momentum = mdict["momentum_buffer"][prev_dim_mask]
unused_m = (
mdict["momentum_buffer"][self.dim_mask.stop :]
.cpu()
.detach()
.float()
.T.numpy()
)
_pca_m = (
momentum.T.cpu().detach().float().numpy() @ pca.components_.T
)
# _m = mbuf.detach().cpu().numpy()
rotated_m = _pca_m
transformed_mbuf = np.hstack((rotated_m, unused_m))
mdict["momentum_buffer"][:] = torch.from_numpy(
transformed_mbuf.T
).to(dtype=mbuf.dtype)

# rotated_b = pca.transform(
# np.array([bias.cpu().detach().float().numpy()])
# ).squeeze()
# transformed_bias = np.hstack((rotated_b, unused_b))

self.pca = pca
self.rotated_w = rotated_w
self.transformed_weight = transformed_weight

# raise RuntimeError("now go inspect")
sd = dict(
weight=torch.from_numpy(transformed_weight.T).to(
dtype=weights.dtype
),
# bias=torch.from_numpy(transformed_bias).bfloat16(),
)
layer.load_state_dict(sd, strict=False)

# reset the train embeddings for the next epoch
self.train_embeddings = []


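For context, the removed on_train_epoch_start block implemented the "PCA reweighting" named in the commit title: whenever the dimension-annealing mask changed, it fit a PCA on the embeddings collected during the epoch and rotated the active rows of a linear layer in the projection head (layers[2]), together with the matching SGD momentum buffer, into that PCA basis before training continued. Below is a minimal, self-contained sketch of that idea only, not the tsimcne API; names such as head, n_active, and embs are illustrative assumptions.

# Sketch of the removed "PCA reweighting" step (illustrative; not the tsimcne
# API).  `head`, `n_active`, and `embs` are assumed names for this example.
import numpy as np
import torch
from sklearn.decomposition import PCA

torch.manual_seed(0)
in_dim, out_dim, n_active = 16, 8, 6   # the last 2 output rows stay untouched
head = torch.nn.Linear(in_dim, out_dim, bias=False)
opt = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)

# One dummy training step so the optimizer holds a momentum buffer to rotate.
x = torch.randn(256, in_dim)
head(x).pow(2).mean().backward()
opt.step()

# "Epoch" embeddings of the currently active output dimensions.
with torch.no_grad():
    embs = head(x)[:, :n_active].numpy()

pca = PCA(n_components=n_active).fit(embs)   # components_: (n_active, n_active)
R = pca.components_

with torch.no_grad():
    W = head.weight.detach().numpy()         # (out_dim, in_dim)
    # R @ W[:n_active] makes the layer emit R @ e for the same input, i.e. the
    # embedding expressed in the PCA basis (ignoring the PCA mean, as the
    # removed code also did for the weights).
    new_W = np.vstack([R @ W[:n_active], W[n_active:]])
    head.weight.copy_(torch.from_numpy(new_W).to(head.weight.dtype))

    # Rotate the SGD momentum buffer with the same matrix so the optimizer
    # state matches the re-parameterised weights.
    buf = opt.state[head.weight]["momentum_buffer"]
    rotated = R @ buf[:n_active].numpy()
    buf[:n_active] = torch.from_numpy(rotated).to(buf.dtype)

Rotating the momentum buffer with the same matrix keeps the optimizer state consistent with the re-parameterised weights, which is what the removed code attempted through self.optimizers().optimizer.state.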