Skip to content

Commit

Permalink
Merge pull request #20 from daniel-code/feat/pytorch_models
Browse files Browse the repository at this point in the history
Feat/pytorch models
  • Loading branch information
daniel-code authored Nov 3, 2022
2 parents bc8d8b7 + 8cc4685 commit 5ffbe49
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 77 deletions.
15 changes: 9 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,19 @@ python train.py -r "datasets/final/train"
python train.py -r "datasets/final/train" --user-pretrained-weight --finetune-last-layer --use-lr-scheduler --use-auto-augment
```

- Training with different model types. See more details in `scripts/different_models.sh`
- Training with different model types. See more details in `scripts/different_models.sh`.
Support [pytorch built-in model types](https://pytorch.org/vision/main/models.html#classification).

```commandline
python train.py -r "datasets/final/train" --model-type resnext50_32x4d
```

Support model types:
- Training with different image size. Some model has image resolution constraint, e.g. vit, only accept image size by (
244, 244).

- ResNet: resnet18, resnet34, resnet_50, resnet_101
- ResNext: resnext50_32x4d, resnext101_32x8d
- Swin: swin_t, swin_s, swin_b
```commandline
python train.py -r "datasets/final/train" --model-type vit_b_16 --image-size 224 224
```

After training, the model weight will export to `model_weights/<model-type>_<exp_time>`.
Use `tensorboard --logdir model_weights` to browse training log.
Expand Down Expand Up @@ -159,7 +161,8 @@ Options:
python analysis.py -r "datasets/final/train" --model-path "model_weights/<model-type>_<exp_time>/model.pt"
```

By default, the `reports/test.png` is AUC of ROC curve and confusion matrix, and the `reports/test_images.jpg` shows the fail cases.
By default, the `reports/test.png` is AUC of ROC curve and confusion matrix, and the `reports/test_images.jpg` shows the
fail cases.

## Inference

Expand Down
7 changes: 3 additions & 4 deletions dogs_cats_classifier/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .resnet import ResNet
from .resnext import ResNext
from .swin import Swin
from .base import ModelBase
from .torch_model_wrapper import is_torch_builtin_models, TorchModelWrapper

__all__ = ['ResNet', 'Swin', 'ResNext']
__all__ = ['ModelBase', 'TorchModelWrapper', 'is_torch_builtin_models']
5 changes: 0 additions & 5 deletions dogs_cats_classifier/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def __init__(self,
self.user_pretrained_weight = user_pretrained_weight
self.finetune_last_layer = finetune_last_layer

self.models_mapping = self._setup_models_mapping()
assert model_type in self.models_mapping, f'{model_type} is not available. There is available model types: {list(self.models_mapping.keys())}'
self.model = self._setup_model(model_type=model_type)

if finetune_last_layer:
Expand All @@ -39,9 +37,6 @@ def __init__(self,

self.example_input_array = torch.zeros((1, 3, input_shape[0], input_shape[1]), dtype=torch.float32)

def _setup_models_mapping(self) -> dict:
raise NotImplementedError

def _setup_model(self, model_type) -> torch.nn.Module:
raise NotImplementedError

Expand Down
20 changes: 0 additions & 20 deletions dogs_cats_classifier/models/resnet.py

This file was deleted.

11 changes: 0 additions & 11 deletions dogs_cats_classifier/models/resnext.py

This file was deleted.

19 changes: 0 additions & 19 deletions dogs_cats_classifier/models/swin.py

This file was deleted.

40 changes: 40 additions & 0 deletions dogs_cats_classifier/models/torch_model_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import types

from .base import ModelBase
from torch.nn import Module, Sequential, Linear


def is_torch_builtin_models(model_type: str):
model_type = model_type.lower()
model_module = __import__('torchvision.models', fromlist=(model_type, ), level=0)
if model_type in dir(model_module):
model_func = getattr(model_module, model_type)
return True if isinstance(model_func, types.FunctionType) else False

return False


class TorchModelWrapper(ModelBase):
def _setup_model(self, model_type) -> Module:
# dynamic import module
model_module = __import__('torchvision.models', fromlist=(model_type, ), level=0)
model_func = getattr(model_module, model_type)
model = model_func(weights='DEFAULT' if self.user_pretrained_weight else None)

# get last layer's name
layer_name = list(model.named_children())[-1][0]

# check last layer
last_layer = model.get_submodule(layer_name)
if isinstance(last_layer, Sequential):
last_layer = last_layer[-1]
in_features = last_layer.in_features

# replace the last layer
last_layer = getattr(model, layer_name)
if isinstance(last_layer, Sequential):
last_layer[-1] = Linear(in_features=in_features, out_features=self.num_classes)
else:
setattr(model, layer_name, Linear(in_features=in_features, out_features=self.num_classes))

return model
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name='dogs_cats_classifier',
packages=find_packages(),
version='0.1.0',
version='0.1.1',
description='Create an algorithm to distinguish dogs from cats',
author='YanRu',
license="MIT",
Expand Down
19 changes: 8 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pytorch_lightning.callbacks import LearningRateMonitor

from dogs_cats_classifier.data import DogsCatsImagesDataModule
from dogs_cats_classifier.models import ResNet, Swin, ResNext
from dogs_cats_classifier.models import TorchModelWrapper, is_torch_builtin_models
from dogs_cats_classifier.utils import Evaluator
from datetime import datetime

Expand Down Expand Up @@ -78,16 +78,12 @@ def main(batch_size, max_epochs, num_workers, image_size, dataset_root, fast_dev
print(dogs_cats_datamodule)

# prepare model
if 'swin' in model_type:
model = Swin
elif 'resnext' in model_type:
model = ResNext
elif 'resnet' in model_type:
model = ResNet
if is_torch_builtin_models(model_type):
model_class = TorchModelWrapper
else:
raise ValueError(f'{model_type} is not available.')

model = model(
model = model_class(
num_classes=1,
model_type=model_type,
input_shape=image_size,
Expand Down Expand Up @@ -117,9 +113,10 @@ def main(batch_size, max_epochs, num_workers, image_size, dataset_root, fast_dev
torch.jit.save(script_model, os.path.join(output_path, 'model.pt'))

# evaluation
dogs_cats_datamodule.setup()
evaluator = Evaluator(model=model, output_path=output_path)
evaluator.evaluate(dataloader=dogs_cats_datamodule.test_dataloader(), title=f'{model_type}_test', verbose=False)
if not fast_dev_run:
dogs_cats_datamodule.setup()
evaluator = Evaluator(model=model, output_path=output_path)
evaluator.evaluate(dataloader=dogs_cats_datamodule.test_dataloader(), title=f'{model_type}_test', verbose=False)


if __name__ == '__main__':
Expand Down

0 comments on commit 5ffbe49

Please sign in to comment.