
Bugfixes and documentation #1

Open · wants to merge 14 commits into main
Conversation

@J-Gann commented Jul 24, 2023

  • Added documentation for pruning of pretrained models
  • Fixed bugs regarding pruning of non-ResNet models
    • Problems in the traversal of network layers while collecting layer information
    • Problems where the last network layer was pruned
  • Added a new method for loading pretrained models
  • Fixed a permission issue with wandb on the cluster
  • Added reference models

@@ -16,7 +16,11 @@ def _get_info(self, layer_key: str, layer_list: list[LayerInfo], full_key, paren
key_elements = layer_key.split(".")

if len(key_elements) > 1:
parents = {info.var_name: info for info in layer_list if not info.is_leaf_layer}
if parent_info.parent_info:
Here, some models threw an exception. During the traversal of the model layers, when listing the children of a parent, children of children were not excluded, which led to an incorrect traversal of the layer tree. I excluded all children of children by adding the condition `and info.parent_info.var_name == parent_info.var_name`.
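
A minimal sketch of the fixed filter (illustrative; the names come from the hunk above, and the exact surrounding code is an assumption):

```python
# Hypothetical sketch: when collecting the children of parent_info, keep only
# direct children (entries whose own parent is parent_info), so children of
# children no longer leak into the traversal.
children = [
    info for info in layer_list
    if info.parent_info is not None
    and info.parent_info.var_name == parent_info.var_name
]
```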

@@ -225,6 +225,8 @@ def _is_layer_compatible(self, layer_key, model_info) -> bool:
return False
return True
if isinstance(layer_info.module, torch.nn.Conv2d):
if layer_info.output_size == []:

An unhandled state of `output_size` that I ran into: for some models it is an empty list.
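
A minimal sketch of the added guard, assuming the method reports an incompatible layer by returning False, as the surrounding hunk suggests:

```python
# Hypothetical sketch: an empty output_size means no shape was recorded for
# this Conv2d layer, so treat it as incompatible instead of letting the
# downstream size arithmetic raise.
if isinstance(layer_info.module, torch.nn.Conv2d):
    if layer_info.output_size == []:
        return False
```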

@@ -29,6 +29,9 @@ def __init__(self,
self._frozen_layers = frozen_layers
self._layer_dict = {layer_key: module for layer_key, module in self._reference_model.named_modules() if
not [*module.children()]}
last_layer = list(self._layer_dict.keys())[-1]

I discovered that galen assumes the last layer of the model is named "fc", as stated e.g. here. This leads to unexpected and difficult-to-resolve errors during pruning. I propose always adding the last layer of the network to the list of frozen layers. Alternatively, this assumption should be mentioned in the documentation.
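
A minimal sketch of the proposal, assuming `_layer_dict` preserves `named_modules()` order (so its last key is the final leaf layer) and `_frozen_layers` is a list of layer keys:

```python
# Hypothetical sketch: always freeze the final (classifier) layer so the
# pruner never touches its output dimension, regardless of the layer's name.
last_layer = list(self._layer_dict.keys())[-1]
if last_layer not in self._frozen_layers:
    self._frozen_layers.append(last_layer)
```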

@@ -208,7 +208,7 @@ def parse_arguments() -> Namespace:
log_file_name=args.log_name
)

-wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args),
+wandb.init(project=args.wandb_project, config=vars(args),
entity=args.wandb_entity leads to permission problems on the cluster when not running as superuser: wandb tries to access the /tmp folder. Removing this argument resolves the problem.

name, repo = select_str.split("@")
model = torch.hub.load(repo, name, pretrained=True, num_classes=num_classes)
elif "/" in select_str:

A proposal for an additional method of loading pretrained models that were saved with torch.save(model, PATH).
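
A minimal sketch of the proposed branch (illustrative; treating a select string containing "/" as a file path is an assumption based on the hunk above):

```python
# Hypothetical sketch: "name@repo" selects a torch.hub model as before, while
# a select string containing "/" is treated as a path to a model that was
# saved with torch.save(model, PATH) and is restored with torch.load.
if "@" in select_str:
    name, repo = select_str.split("@")
    model = torch.hub.load(repo, name, pretrained=True, num_classes=num_classes)
elif "/" in select_str:
    model = torch.load(select_str)  # unpickles the full module object
```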

@@ -28,7 +28,7 @@ dependencies:
- matplotlib
- pandas
- pip:
-  - torch-pruning
+  - torch-pruning==0.2.8

The torch-pruning API changed after version 0.2.8; using a newer release would require refactoring the galen code, so the version is pinned here.
