From f62b7b53361b918f07c6461597f476cb9b435c2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niklas=20B=C3=B6hm?= Date: Thu, 12 Dec 2024 21:41:05 +0100 Subject: [PATCH] Remove PCA reweighting --- tsimcne/tsimcne.py | 108 --------------------------------------------- 1 file changed, 108 deletions(-) diff --git a/tsimcne/tsimcne.py b/tsimcne/tsimcne.py index abbb328..de5885c 100644 --- a/tsimcne/tsimcne.py +++ b/tsimcne/tsimcne.py @@ -44,7 +44,6 @@ def __init__( weight_decay=5e-4, momentum=0.9, warmup_epochs=10, - dim_anneal_strategy="pca", batches_per_epoch=None, random_state=None, save_intermediate_feat=False, @@ -79,7 +78,6 @@ def __init__( self.weight_decay = weight_decay self.momentum = momentum self.warmup_epochs = warmup_epochs - self.dim_anneal_strategy = dim_anneal_strategy self.batches_per_epoch = batches_per_epoch self.random_state = random_state self.save_intermediate_feat = save_intermediate_feat @@ -226,21 +224,6 @@ def _handle_parameters(self): f"got {self.optimizer_name}." ) - # if self.lr_scheduler_name not in ["cos_annealing", "constant"]: - # raise ValueError( - # "Only 'cos_annealing' and 'constant' is supported as " - # f"learning rate scheduler, got {self.lr_scheduler_name}." - # ) - - if ( - self.dim_anneal_strategy is not None - and self.dim_anneal_strategy != "pca" - ): - raise ValueError( - "Expected None or 'pca' for dim_anneal strategy, got " - f"{self.dim_anneal_strategy}." - ) - self.alphas = torch.sin( torch.linspace(0, 1, self.n_epochs) * torch.pi / 2 ) @@ -329,100 +312,9 @@ def on_train_epoch_start(self): else: self.log("dof", self.cur_dof, prog_bar=False) - prev_dim_mask = self.dim_mask self.dim_mask = self.dim_mask_schedule[self.current_epoch] self.cur_dof = self.dofs[self.current_epoch] - _embeddings = self._embeddings = self.train_embeddings - - # next_output_dim = self.out_dim + self.dim_mask.stop - # vv and next_output_dim < 10 - do_pca = ( - self.dim_anneal_strategy == "pca" - and prev_dim_mask != self.dim_mask - and not isinstance(self.model.projection_head, torch.nn.Identity) - ) - if do_pca: - layer = self.model.projection_head.layers[2] - self.weights = weights = layer.weight[prev_dim_mask] - # bias = layer.bias[prev_dim_mask] - - # weights.T.cpu().detach() - embs = torch.vstack(_embeddings).cpu().detach().float() - unused_w = ( - layer.weight[self.dim_mask.stop :] - .cpu() - .detach() - .float() - .T.numpy() - ) - # unused_b = ( - # layer.bias[self.dim_mask.stop :].cpu().detach().float().numpy() - # ) - - from sklearn.decomposition import PCA - - self.pca = pca = PCA( - min(self.out_dim, self.out_dim + self.dim_mask.stop) - ).fit(embs) - # make_pipeline( - # StandardScaler(with_std=False), - # TruncatedSVD( - # min(self.out_dim, self.out_dim + self.dim_mask.stop) - # ), - # ) - - # _pca_w = pca[1].transform(weights.T.cpu().detach().float().numpy()) - _pca_w = ( - weights.T.cpu().detach().float().numpy() @ pca.components_.T - ) - # _e = embs.numpy() - # _w = weights.detach().cpu().numpy() - rotated_w = _pca_w # * _w.std() # + _w.mean() - transformed_weight = np.hstack((rotated_w, unused_w)) - - odict = self.optimizers().optimizer.state - mdict = odict[layer.weight] - # mdict.clear() - # print(len(odict.keys()), odict[weights]) - ## attempt at pca-transforming the momentum - mbuf = mdict["momentum_buffer"] - momentum = mdict["momentum_buffer"][prev_dim_mask] - unused_m = ( - mdict["momentum_buffer"][self.dim_mask.stop :] - .cpu() - .detach() - .float() - .T.numpy() - ) - _pca_m = ( - momentum.T.cpu().detach().float().numpy() @ pca.components_.T - ) - # _m = mbuf.detach().cpu().numpy() - rotated_m = _pca_m - transformed_mbuf = np.hstack((rotated_m, unused_m)) - mdict["momentum_buffer"][:] = torch.from_numpy( - transformed_mbuf.T - ).to(dtype=mbuf.dtype) - - # rotated_b = pca.transform( - # np.array([bias.cpu().detach().float().numpy()]) - # ).squeeze() - # transformed_bias = np.hstack((rotated_b, unused_b)) - - self.pca = pca - self.rotated_w = rotated_w - self.transformed_weight = transformed_weight - - # raise RuntimeError("now go inspect") - sd = dict( - weight=torch.from_numpy(transformed_weight.T).to( - dtype=weights.dtype - ), - # bias=torch.from_numpy(transformed_bias).bfloat16(), - ) - layer.load_state_dict(sd, strict=False) - # reset the train embeddings for the next epoch self.train_embeddings = []