# ===========================================================================
# Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB
# Paper: arxiv.org/abs/2306.16788
# File: strategies/strategies.py
# Description: Sparsification strategies for regular training
# ===========================================================================
import sys
from collections import OrderedDict
import torch
import torch.nn.utils.prune as prune
#### Dense Base Class
class Dense:
"""Dense base class for defining callbacks, does nothing but showing the structure and inherits."""
required_params = []
def __init__(self, **kwargs):
self.masks = dict()
self.lr_dict = OrderedDict() # it:lr
self.is_in_finetuning_phase = False
self.model = kwargs['model']
self.run_config = kwargs['config']
self.callbacks = kwargs['callbacks']
self.goal_sparsity = self.run_config['goal_sparsity']
self.optimizer = None # To be set
self.n_total_iterations = None
def after_initialization(self):
"""Called after initialization of the strategy"""
        self.parameters_to_prune = [
            (module, 'weight') for name, module in self.model.named_modules()
            if hasattr(module, 'weight')
            and module.weight is not None
            and not isinstance(module, torch.nn.BatchNorm2d)
        ]
self.n_prunable_parameters = sum(
getattr(module, param_type).numel() for module, param_type in self.parameters_to_prune)
def set_optimizer(self, opt, **kwargs):
self.optimizer = opt
if 'n_total_iterations' in kwargs:
self.n_total_iterations = kwargs['n_total_iterations']
@torch.no_grad()
def after_training_iteration(self, **kwargs):
"""Called after each training iteration"""
if not self.is_in_finetuning_phase:
self.lr_dict[kwargs['it']] = kwargs['lr']
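            # The recorded iteration -> lr schedule covers only the regular training phase;
            # subclasses may reuse it, e.g. to rebuild a schedule for finetuning
            # (assumption: it is not consumed anywhere in this file).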
def at_train_begin(self):
"""Called before training begins"""
pass
def at_epoch_start(self, **kwargs):
"""Called before the epoch starts"""
pass
def at_epoch_end(self, **kwargs):
"""Called at epoch end"""
pass
def at_train_end(self, **kwargs):
"""Called at the end of training"""
pass
def final(self):
pass
@torch.no_grad()
def pruning_step(self, pruning_sparsity, only_save_mask=False, compute_from_scratch=False):
if compute_from_scratch:
# We have to revert to weight_orig and then compute the mask
for module, param_type in self.parameters_to_prune:
if prune.is_pruned(module):
# Enforce the equivalence of weight_orig and weight
orig = getattr(module, param_type + "_orig").detach().clone()
prune.remove(module, param_type)
p = getattr(module, param_type)
p.copy_(orig)
del orig
elif only_save_mask and len(self.masks) > 0:
for module, param_type in self.parameters_to_prune:
if (module, param_type) in self.masks:
prune.custom_from_mask(module, param_type, self.masks[(module, param_type)])
        if self.run_config['pruning_selector'] == 'uniform':
# We prune each layer individually
for module, param_type in self.parameters_to_prune:
prune.l1_unstructured(module, name=param_type, amount=pruning_sparsity)
else:
# Default: prune globally
prune.global_unstructured(
self.parameters_to_prune,
pruning_method=self.get_pruning_method(),
amount=pruning_sparsity,
)
self.masks = dict() # Stays empty if we use regular pruning
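        # Sketch of the torch.nn.utils.prune semantics relied upon here: a pruned module
        # holds 'weight_orig' (parameter) and 'weight_mask' (buffer) and recomputes
        # module.weight = weight_orig * weight_mask via a forward pre-hook until
        # prune.remove() makes the pruning permanent.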
if only_save_mask:
for module, param_type in self.parameters_to_prune:
if prune.is_pruned(module):
# Save the mask
mask = getattr(module, param_type + '_mask')
self.masks[(module, param_type)] = mask.detach().clone()
setattr(module, param_type + '_mask', torch.ones_like(mask))
# Remove (i.e. make permanent) the reparameterization
prune.remove(module=module, name=param_type)
# Delete the temporary mask to free memory
del mask
def enforce_prunedness(self):
"""
Makes the pruning permanent, i.e. set the pruned weights to zero, than reinitialize from the same mask
This ensures that we can actually work (i.e. LMO, rescale computation) with the parameters
Important: For this to work we require that pruned weights stay zero in weight_orig over training
hence training, projecting etc should not modify (pruned) 0 weights in weight_orig
"""
for module, param_type in self.parameters_to_prune:
if prune.is_pruned(module):
# Save the mask
mask = getattr(module, param_type + '_mask')
# Remove (i.e. make permanent) the reparameterization
prune.remove(module=module, name=param_type)
# Reinitialize the pruning
prune.custom_from_mask(module=module, name=param_type, mask=mask)
# Delete the temporary mask to free memory
del mask
def prune_momentum(self):
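        """Zero out momentum-buffer entries belonging to pruned weights so that their
        weight_orig entries stay zero during further training (no-op without a momentum buffer)."""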
opt_state = self.optimizer.state
for module, param_type in self.parameters_to_prune:
if prune.is_pruned(module):
# Enforce the prunedness of momentum buffer
param_state = opt_state[getattr(module, param_type + "_orig")]
if 'momentum_buffer' in param_state:
mask = getattr(module, param_type + "_mask")
param_state['momentum_buffer'] *= mask.to(dtype=param_state['momentum_buffer'].dtype)
def get_pruning_method(self):
raise NotImplementedError("Dense has no pruning method, this must be implemented in each child class.")
@torch.no_grad()
def make_pruning_permanent(self):
"""Makes the pruning permanent and removes the pruning hooks"""
        # Note: prune.remove() does not undo the pruning; it removes the reparameterization so the pruned weights stay at zero permanently
if len(self.masks) == 0:
for module, param_type in self.parameters_to_prune:
if prune.is_pruned(module):
prune.remove(module, param_type)
else:
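            # Masks were only stored (only_save_mask path), so the reparameterization was
            # already removed; apply the stored masks directly to the dense weight tensors.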
for module, param_type in self.masks:
# Get the mask
mask = self.masks[(module, param_type)]
# Apply the mask
orig = getattr(module, param_type)
orig *= mask
self.masks = dict()
def set_to_finetuning_phase(self):
self.is_in_finetuning_phase = True
class IMP(Dense):
"""Iterative Magnitude Pruning Base Class"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.phase = self.run_config['phase']
self.n_phases = self.run_config['n_phases']
self.n_epochs_per_phase = self.run_config['n_epochs_per_phase']
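        # Assumption from the config layout: each training run executes a single IMP phase
        # ('phase' is read from the run config), following the usual train -> prune -> finetune cycle.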
def at_train_end(self, **kwargs):
        # Fraction of the remaining weights pruned in each phase; compounding over n_phases yields goal_sparsity
prune_per_phase = 1 - (1 - self.goal_sparsity) ** (1. / self.n_phases)
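        # Worked example (illustrative): goal_sparsity = 0.9 and n_phases = 3 give
        # prune_per_phase = 1 - (1 - 0.9) ** (1 / 3) ≈ 0.536, i.e. each phase removes
        # ~53.6% of the remaining weights and 1 - (1 - 0.536) ** 3 ≈ 0.9 overall.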
phase = self.phase
self.pruning_step(pruning_sparsity=prune_per_phase)
self.current_sparsity = 1 - (1 - prune_per_phase) ** phase
self.callbacks['after_pruning_callback']()
self.finetuning_step(pruning_sparsity=prune_per_phase, phase=phase)
def finetuning_step(self, pruning_sparsity, phase):
self.callbacks['finetuning_callback'](pruning_sparsity=pruning_sparsity,
n_epochs_finetune=self.n_epochs_per_phase,
phase=phase)
def get_pruning_method(self):
if self.run_config['pruning_selector'] in ['global', 'uniform']:
            # For 'uniform' the returned method is not actually used: pruning_step prunes each layer with L1 directly
return prune.L1Unstructured
elif self.run_config['pruning_selector'] == 'random':
return prune.RandomUnstructured
else:
raise NotImplementedError
def final(self):
super().final()
self.callbacks['final_log_callback']()
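
# Usage sketch (illustrative, not part of the original training code): a training loop is
# expected to drive a strategy roughly as follows, with 'model', 'config', 'callbacks',
# 'optimizer' and the loop variables supplied by the caller:
#
#   strategy = IMP(model=model, config=config, callbacks=callbacks)
#   strategy.after_initialization()
#   strategy.set_optimizer(optimizer, n_total_iterations=n_total_iterations)
#   for it, lr in training_loop():          # regular training
#       strategy.after_training_iteration(it=it, lr=lr)
#   strategy.at_train_end()                 # prune, then finetune via the callbacks
#   strategy.make_pruning_permanent()
#   strategy.final()                        # final logging callback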