Skip to content

Commit

Permalink
Merge pull request #21 from jcreinhold/static_upsample
Browse files Browse the repository at this point in the history
Static upsample
  • Loading branch information
jcreinhold authored Jan 11, 2022
2 parents 89ab90e + 7cd7f98 commit f5999be
Show file tree
Hide file tree
Showing 14 changed files with 376 additions and 109 deletions.
7 changes: 7 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
History
=======

0.2.0 (2022-01-11)
-------------------

* Add option to upsample with interpolation instead of transpose conv.
* Remove separate padding layer and use conv. built-in padding
* Improvements to ONNX converter

0.1.37 (2021-12-16)
-------------------

Expand Down
8 changes: 8 additions & 0 deletions docs/tiramisu_brulee.experiment.cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ tiramisu\_brulee.experiment.cli.predict module
:undoc-members:
:show-inheritance:

tiramisu\_brulee.experiment.cli.to\_onnx module
-----------------------------------------------

.. automodule:: tiramisu_brulee.experiment.cli.to_onnx
:members:
:undoc-members:
:show-inheritance:

tiramisu\_brulee.experiment.cli.train module
--------------------------------------------

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ profile = "black"

[tool.bandit]
skips = ["B101"]

2 changes: 2 additions & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ numpy~=1.22.0; python_version >= "3.8"
pandas~=1.1.5; python_version < "3.7"
pandas>=1.3.4; python_version >= "3.7"
pandas-stubs
pillow>=8.4.0; python_version == "3.6"
pillow>=9.0.0; python_version >= "3.7"
pytorch-lightning~=1.5.1
PyYAML>=5.4.1
ruyaml>=0.20.0
Expand Down
25 changes: 13 additions & 12 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.37
current_version = 0.2.0
commit = True
tag = True

Expand All @@ -9,7 +9,7 @@ replace = __version__ = "{new_version}"

[flake8]
exclude = docs
per-file-ignores =
per-file-ignores =
__init__.py: F401
max-line-length = 88
extend-ignore = E203
Expand All @@ -32,10 +32,10 @@ keywords = tiramisu, segmentation, neural network, convolutional, pytorch
license = Apache Software License 2.0
license_file = LICENSE
url = https://github.com/jcreinhold/tiramisu-brulee
project_urls =
project_urls =
Bug Tracker = https://github.com/jcreinhold/tiramisu-brulee/issues
Documentation = https://tiramisu-brulee.readthedocs.io/
classifiers =
classifiers =
Development Status :: 4 - Beta
Intended Audience :: Developers
License :: OSI Approved :: Apache Software License
Expand All @@ -54,26 +54,27 @@ zip_safe = False
include_package_data = True
packages = find:
python_requires = >= 3.6
install_requires =
install_requires =
torch
test_suite = tests

[options.packages.find]
include =
include =
tiramisu_brulee
tiramisu_brulee.*
exclude =
exclude =
tests
docs

[options.package_data]
tiramisu_brulee = py.typed

[options.extras_require]
lesionseg =
lesionseg =
jsonargparse~=3.12.0
numpy
pandas
pillow>=9.0.0
pytorch-lightning~=1.5.1
PyYAML
ruyaml
Expand All @@ -83,12 +84,12 @@ lesionseg =
torchio
torchmetrics
mlflow = mlflow
onnx =
onnx
onnxruntime
onnx =
onnx
onnxruntime

[options.entry_points]
console_scripts =
console_scripts =
lesion-train = tiramisu_brulee.experiment.cli.train:train
lesion-predict = tiramisu_brulee.experiment.cli.predict:predict
lesion-predict-image = tiramisu_brulee.experiment.cli.predict:predict_image
Expand Down
2 changes: 1 addition & 1 deletion tiramisu_brulee/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__url__ = "https://github.com/jcreinhold/tiramisu-brulee"
__author__ = """Jacob Reinhold"""
__email__ = "[email protected]"
__version__ = "0.1.37"
__version__ = "0.2.0"
__license__ = "Apache-2.0"
__copyright__ = "Copyright 2021 Jacob Reinhold"

Expand Down
50 changes: 46 additions & 4 deletions tiramisu_brulee/experiment/cli/to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,29 @@ def arg_parser() -> ArgParser:
action="store_true",
help="don't add metadata",
)
parser.add_argument(
"--no-dynamic-batch",
action="store_true",
help="don't use dynamic batches",
)
parser.add_argument(
"--no-dynamic-shape",
action="store_true",
help="don't use dynamic shapes",
)
parser.add_argument(
"--batch-size",
default=1,
type=int,
help="input batch size (important if no-dynamic-batch enabled)",
)
parser.add_argument(
"--image-shape",
default=None,
nargs="+",
type=int,
help="input image shape (important if no-dynamic-shape enabled)",
),
parser.add_argument(
"-v",
"--verbosity",
Expand Down Expand Up @@ -160,8 +183,12 @@ def to_onnx(args: ArgType = None) -> builtins.int:
root, base, ext = split_filename(args.onnx_path)
onnx_path = root / (base + f"_{i}" + ext)
nth_model = f" ({i}/{n_models})" if n_models > 1 else ""
if args.image_shape is None:
args.image_shape = (128,) * (3 if p3d is None else 2)
model = LesionSegLightningTiramisu.load_from_checkpoint(
str(model_path),
input_shape=args.image_shape,
static_upsample=args.no_dynamic_shape,
_model_num=model_num,
)
if args.prune:
Expand All @@ -176,11 +203,20 @@ def to_onnx(args: ArgType = None) -> builtins.int:
parameters_to_prune, prune.L1Unstructured, amount=args.prune_amount
)
n_channels = n_inputs if p3d is None else (args.pseudo3d_size * n_inputs)
input_shape = (1, n_channels) + (128,) * (3 if p3d is None else 2)
input_shape = (args.batch_size, n_channels) + tuple(args.image_shape)
logger.debug(f"Input shape: {input_shape}")
input_sample = torch.randn(input_shape)
axes = {0: "batch_size", 2: "h", 3: "w"}
if p3d is None:
axes = dict()
if not args.no_dynamic_batch:
axes.update({0: "batch_size"})
if not args.no_dynamic_shape:
axes.update({2: "h", 3: "w"})
if not args.no_dynamic_shape and p3d is None:
axes.update({4: "d"})
if not args.no_dynamic_batch or not args.no_dynamic_shape:
dynamic_axes = {"input": axes, "output": axes}
else:
dynamic_axes = None
with tempfile.NamedTemporaryFile("w") as f:
save_as_ort = str(onnx_path).endswith(".ort")
file_path = f.name if save_as_ort else onnx_path
Expand All @@ -193,7 +229,8 @@ def to_onnx(args: ArgType = None) -> builtins.int:
do_constant_folding=args.do_constant_folding,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": axes, "output": axes},
dynamic_axes=dynamic_axes,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
)
logger.info("Exporting model to ONNX" + nth_model)
if args.verbosity >= 3:
Expand Down Expand Up @@ -290,6 +327,11 @@ def add_metadata(
if args.pseudo3d_dim is not None:
doc_string += f" p3d:{args.pseudo3d_dim[i]}"
doc_string += f" p3s:{args.pseudo3d_size}"
if args.no_dynamic_batch:
doc_string += f" static-batch-size={args.batch_size}"
if args.no_dynamic_shape:
image_shape = str(args.image_shape).replace(" ", "")
doc_string += f" static-shape={image_shape}"
model = onnx.load_model(onnx_model_path)
model.producer_name = producer_name
model.producer_version = producer_version
Expand Down
1 change: 1 addition & 0 deletions tiramisu_brulee/experiment/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def __init__( # type: ignore[no-untyped-def]
super().__init__(*args, **kwargs)
self.mlflow_logger = mlflow_logger

# flake8: noqa: E501
def save_checkpoint(
self,
trainer: Trainer,
Expand Down
53 changes: 46 additions & 7 deletions tiramisu_brulee/experiment/seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
]

import builtins
import enum
import functools
import logging
import typing
Expand Down Expand Up @@ -60,12 +61,31 @@
l1_segmentation_loss,
mse_segmentation_loss,
)
from tiramisu_brulee.model import Tiramisu2d, Tiramisu3d
from tiramisu_brulee.util import init_weights
from tiramisu_brulee.model import ResizeMethod, Tiramisu2d, Tiramisu3d
from tiramisu_brulee.util import InitType, init_weights

PredictBatch = typing.Union[PatchesImagePredictBatch, WholeImagePredictBatch]


@enum.unique
class LossFunction(enum.Enum):
COMBO: builtins.str = "combo"
L1: builtins.str = "l1"
MSE: builtins.str = "mse"

@classmethod
def from_string(cls, string: builtins.str) -> "LossFunction":
if string.lower() == "combo":
return cls.COMBO
elif string.lower() == "l1":
return cls.L1
elif string.lower() == "mse":
return cls.MSE
else:
msg = f"Only 'combo', 'l1', 'mse' allowed. Got {string}"
raise ValueError(msg)


class LesionSegLightningBase(pl.LightningModule):
"""PyTorch-Lightning module for lesion segmentation
Expand Down Expand Up @@ -144,7 +164,10 @@ def setup(self, stage: typing.Optional[builtins.str] = None) -> None:
self.criterion: typing.Callable
num_classes = self.hparams.num_classes
assert isinstance(num_classes, builtins.int)
if self.hparams.loss_function == "combo":
loss_func_str = self.hparams.loss_function
assert isinstance(loss_func_str, builtins.str)
loss_func = LossFunction.from_string(loss_func_str)
if loss_func == LossFunction.COMBO:
if self.hparams.num_classes == 1:
self.criterion = functools.partial(
binary_combo_loss,
Expand All @@ -161,10 +184,10 @@ def setup(self, stage: typing.Optional[builtins.str] = None) -> None:
else:
msg = f"num_classes must be greater than zero. Got {self.num_classes}."
raise ValueError(msg)
elif self.hparams.loss_function == "mse":
self.criterion = mse_segmentation_loss
elif self.hparams.loss_function == "l1":
elif loss_func == LossFunction.L1:
self.criterion = l1_segmentation_loss
elif loss_func == LossFunction.MSE:
self.criterion = mse_segmentation_loss
else:
raise ValueError(f"{self.hparams.loss_function} not supported.")
use_mixup = bool(self.hparams.mixup)
Expand Down Expand Up @@ -704,6 +727,7 @@ def add_testing_arguments(parent_parser: ArgParser) -> ArgParser:
return parent_parser


# flake8: noqa: E501
class LesionSegLightningTiramisu(LesionSegLightningBase):
"""3D Tiramisu-based PyTorch-Lightning module for lesion segmentation
Expand Down Expand Up @@ -785,6 +809,9 @@ def __init__( # type: ignore[no-untyped-def]
mixup: builtins.bool = False,
mixup_alpha: builtins.float = 0.4,
num_input: builtins.int = 1,
resize_method: builtins.str = "crop",
input_shape: typing.Optional[typing.Tuple[builtins.int, ...]] = None,
static_upsample: builtins.bool = True,
_model_num: ModelNum = ModelNum(1, 1),
**kwargs,
):
Expand All @@ -804,8 +831,11 @@ def __init__( # type: ignore[no-untyped-def]
growth_rate=growth_rate,
first_conv_out_channels=first_conv_out_channels,
dropout_rate=dropout_rate,
resize_method=ResizeMethod.from_string(resize_method),
input_shape=input_shape,
static_upsample=static_upsample,
)
init_weights(network, init_type=init_type, gain=gain)
init_weights(network, init_type=InitType.from_string(init_type), gain=gain)
super().__init__(
network=network,
n_epochs=n_epochs,
Expand Down Expand Up @@ -915,4 +945,13 @@ def add_model_arguments(parent_parser: ArgParser) -> ArgParser:
default=48,
help="number of output channels in first conv",
)
parser.add_argument(
"-rm",
"--resize-method",
type=str,
default="crop",
choices=("crop", "interpolate"),
help="use transpose conv and crop or normal conv "
"and interpolate to correct size in upsample branch",
)
return parent_parser
2 changes: 2 additions & 0 deletions tiramisu_brulee/experiment/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
PathLike = typing.Union[builtins.str, os.PathLike]


# flake8: noqa: E501
def return_none(func: typing.Callable) -> typing.Callable:
def new_func(self, string: typing.Any) -> typing.Any: # type: ignore[no-untyped-def]
if string is None:
Expand All @@ -78,6 +79,7 @@ def new_func(self, string: typing.Any) -> typing.Any: # type: ignore[no-untyped
return new_func


# flake8: noqa: E501
def return_str(match_string: builtins.str) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
def new_func(self, string: typing.Any) -> typing.Any: # type: ignore[no-untyped-def]
Expand Down
1 change: 1 addition & 0 deletions tiramisu_brulee/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from tiramisu_brulee.model.dense import ResizeMethod
from tiramisu_brulee.model.tiramisu import Tiramisu2d, Tiramisu3d
Loading

0 comments on commit f5999be

Please sign in to comment.