Bugfixes and documentation #1
base: main
Conversation
J-Gann commented on Jul 24, 2023
- Added documentation for pruning of pretrained models
- Fixed bugs in the pruning of non-ResNet models:
  - problems in the traversal of network layers while collecting layer information
  - problems where the last network layer was pruned
- Added a new method for loading pretrained models
- Fixed a permission issue with wandb on the cluster
- Added reference models
@@ -16,7 +16,11 @@ def _get_info(self, layer_key: str, layer_list: list[LayerInfo], full_key, paren
    key_elements = layer_key.split(".")

    if len(key_elements) > 1:
        parents = {info.var_name: info for info in layer_list if not info.is_leaf_layer}
        if parent_info.parent_info:
Here, some models threw an exception. During the traversal of the model layers, when listing the children of a parent, children of children were not excluded, which led to an incorrect traversal of the layer tree. I excluded all children of children by adding the condition and info.parent_info.var_name == parent_info.var_name.
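The filtering described above can be sketched as follows. This is a minimal, self-contained illustration, not galen's actual code: the LayerInfo stand-in carries only the fields the condition needs (var_name, parent_info), and the toy layer tree is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LayerInfo:
    """Simplified stand-in for galen's LayerInfo."""
    var_name: str
    parent_info: Optional["LayerInfo"] = None
    is_leaf_layer: bool = True


# Toy layer tree: model -> layer1 -> conv1 -> conv2
root = LayerInfo("model", None, is_leaf_layer=False)
block = LayerInfo("layer1", root, is_leaf_layer=False)
conv = LayerInfo("conv1", block, is_leaf_layer=False)
inner = LayerInfo("conv2", conv)  # a child of a child, relative to block

layer_list = [root, block, conv, inner]


def direct_children(parent_info: LayerInfo, layer_list: list) -> list:
    # The added condition keeps only *direct* children of parent_info,
    # excluding children of children:
    return [info for info in layer_list
            if info.parent_info
            and info.parent_info.var_name == parent_info.var_name]


# Only conv1 is a direct child of layer1; conv2 is excluded.
print([c.var_name for c in direct_children(block, layer_list)])
```

Without the var_name comparison, a naive "has a parent" check would also pick up conv2 when listing the children of layer1, which is exactly the mistraversal the fix addresses.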
@@ -225,6 +225,8 @@ def _is_layer_compatible(self, layer_key, model_info) -> bool:
            return False
        return True
    if isinstance(layer_info.module, torch.nn.Conv2d):
        if layer_info.output_size == []:
An unhandled state of output_size that I ran into: for some models a Conv2d layer's output_size was still the empty list, and treating such a layer as compatible caused failures downstream.
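The guard in the diff above can be sketched in isolation. In galen the type check is against torch.nn.Conv2d; the stand-in Conv2d and LayerInfo classes below are assumptions that keep the example dependency-free.

```python
class Conv2d:
    """Stand-in for torch.nn.Conv2d (only its type identity matters here)."""


class LayerInfo:
    """Stand-in carrying only what the compatibility check needs."""
    def __init__(self, module, output_size):
        self.module = module
        self.output_size = output_size


def is_layer_compatible(layer_info: LayerInfo) -> bool:
    if isinstance(layer_info.module, Conv2d):
        # The added guard: an empty output_size means the layer's shape
        # was never recorded, so it cannot be pruned safely.
        if layer_info.output_size == []:
            return False
    return True


print(is_layer_compatible(LayerInfo(Conv2d(), [8, 32, 32])))  # shaped layer
print(is_layer_compatible(LayerInfo(Conv2d(), [])))           # unhandled state
```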
@@ -29,6 +29,9 @@ def __init__(self,
        self._frozen_layers = frozen_layers
        self._layer_dict = {layer_key: module for layer_key, module in self._reference_model.named_modules() if
                            not [*module.children()]}
        last_layer = list(self._layer_dict.keys())[-1]
I discovered that galen assumes the last layer of the model is named "fc", as stated e.g. here. This leads to unexpected and difficult-to-resolve errors during pruning. I propose always adding the last layer of the network to the list of frozen layers. Alternatively, this assumption should be stated in the documentation.
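A minimal sketch of the proposal: instead of assuming the final layer is named "fc", take the last leaf module from named_modules() (the same comprehension the diff uses) and append it to the frozen list. The toy Sequential model and the empty frozen_layers list are illustrative assumptions.

```python
import torch

# Toy model; in galen this would be the reference model being pruned.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 10),
)

# Leaf layers only, mirroring the diff's comprehension.
layer_dict = {layer_key: module
              for layer_key, module in model.named_modules()
              if not [*module.children()]}

# The final leaf layer, whatever its name ("2" here, not "fc").
last_layer = list(layer_dict.keys())[-1]

frozen_layers = []  # user-supplied list of layers to exclude from pruning
if last_layer not in frozen_layers:
    frozen_layers.append(last_layer)  # never prune the network head
```

This keeps the classifier head intact for any architecture, which is the behavior the name-based "fc" assumption only provides for ResNet-style models.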
@@ -208,7 +208,7 @@ def parse_arguments() -> Namespace:
        log_file_name=args.log_name
    )

-   wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args),
+   wandb.init(project=args.wandb_project, config=vars(args),
Passing entity=args.wandb_entity leads to permission problems on the cluster when not running as superuser, because wandb tries to access the /tmp folder. Removing this argument resolves the problem.
    name, repo = select_str.split("@")
    model = torch.hub.load(repo, name, pretrained=True, num_classes=num_classes)
elif "/" in select_str:
A proposal for an additional method of loading pretrained models that were saved using torch.save(model, PATH).
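The loading path being proposed can be sketched as follows. This is a hedged illustration, not the PR's exact code: the toy Linear model and the temp-file path are assumptions, and the class of the saved model must be importable at load time for unpickling to work.

```python
import os
import tempfile

import torch

# A model persisted with torch.save(model, PATH) stores the whole
# nn.Module (architecture + weights), not just a state_dict.
model = torch.nn.Linear(4, 2)
path = os.path.join(tempfile.gettempdir(), "full_model_example.pt")
torch.save(model, path)

# weights_only=False is needed on recent PyTorch to unpickle a full
# module (the safer default only accepts plain tensors/state_dicts).
restored = torch.load(path, weights_only=False)
restored.eval()

print(restored(torch.zeros(1, 4)).shape)
```

This complements the torch.hub branch shown in the diff: "name@repo" strings go through torch.hub.load, while a filesystem path could be dispatched to torch.load.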
@@ -28,7 +28,7 @@ dependencies:
  - matplotlib
  - pandas
  - pip:
-    - torch-pruning
+    - torch-pruning==0.2.8
The torch-pruning API has changed since version 0.2.8; upgrading would require refactoring the galen code, so the dependency is pinned to 0.2.8 here.