Skip to content

Commit

Permalink
Add crop-resize support and fix GPU augmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanklut committed Jul 16, 2024
1 parent c154ff6 commit 60cd4a0
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 61 deletions.
1 change: 1 addition & 0 deletions configs/extra_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@

# Resize options after loading image
_C.INPUT.RESIZE_MODE = "shortest_edge"
_C.INPUT.CROP_RESIZE = False
_C.INPUT.SCALING = 0.0
_C.INPUT.SCALING_TRAIN = 0.5
_C.INPUT.SCALING_TEST = 0.0
Expand Down
76 changes: 68 additions & 8 deletions core/trainer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import copy
import itertools
import logging
import os
import weakref
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Set

import torch
from detectron2.checkpoint import DetectionCheckpointer
Expand All @@ -20,12 +21,18 @@
create_ddp_model,
hooks,
)
from detectron2.evaluation import SemSegEvaluator
from detectron2.evaluation import (
DatasetEvaluator,
SemSegEvaluator,
inference_on_dataset,
print_csv_format,
)
from detectron2.projects.deeplab import build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping, reduce_param_groups
from detectron2.utils import comm

from data.mapper import Mapper
from utils.logging_utils import get_logger_name


def get_default_optimizer_params(
Expand Down Expand Up @@ -219,9 +226,10 @@ def __init__(self, cfg: CfgNode, validation: bool = False):
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg) if not validation else None

model = create_ddp_model(model, broadcast_buffers=False)

data_loader = self.build_train_loader(cfg, device=model.device) if not validation else None

self._trainer = (AMPTrainer if cfg.MODEL.AMP_TRAIN.ENABLED else SimpleTrainer)(model, data_loader, optimizer)
if isinstance(self._trainer, AMPTrainer):
precision_converter = {
Expand Down Expand Up @@ -295,18 +303,18 @@ def build_evaluator(cls, cfg, dataset_name):
return evaluator

@classmethod
def build_train_loader(cls, cfg, device=torch.device("cpu")):
    """Build the training data loader for a supported meta-architecture.

    Args:
        cfg (CfgNode): full training configuration; ``cfg.MODEL.META_ARCHITECTURE``
            selects the model family and ``cfg.DATALOADER.PIN_MEMORY`` controls
            host-memory pinning.
        device (torch.device): device handed to the ``Mapper`` so augmentations
            can run on the GPU; defaults to CPU for backward compatibility.

    Returns:
        An iterable detectron2 training data loader wrapping the mapper.

    Raises:
        NotImplementedError: if the configured meta-architecture is not one of
            the semantic-segmentation families handled by ``Mapper``.
    """
    # Only semantic-segmentation style architectures have a Mapper implementation.
    if cfg.MODEL.META_ARCHITECTURE in ["SemanticSegmentor", "MaskFormer", "PanopticFPN"]:
        mapper = Mapper(cfg, mode="train", device=device)  # type: ignore
    else:
        raise NotImplementedError(f"Current META_ARCHITECTURE type {cfg.MODEL.META_ARCHITECTURE} not supported")

    return build_detection_train_loader(cfg=cfg, mapper=mapper, pin_memory=cfg.DATALOADER.PIN_MEMORY)  # type: ignore

@classmethod
def build_test_loader(cls, cfg, dataset_name):
def build_test_loader(cls, cfg, dataset_name, device=torch.device("cpu")):
if cfg.MODEL.META_ARCHITECTURE in ["SemanticSegmentor", "MaskFormer", "PanopticFPN"]:
mapper = Mapper(cfg, mode="val") # type: ignore
mapper = Mapper(cfg, mode="val", device=device) # type: ignore
else:
raise NotImplementedError(f"Current META_ARCHITECTURE type {cfg.MODEL.META_ARCHITECTURE} not supported")

Expand All @@ -320,5 +328,57 @@ def build_optimizer(cls, cfg, model):
def build_lr_scheduler(cls, cfg, optimizer):
    """Build the learning-rate scheduler for *optimizer* from *cfg*.

    NOTE: this classmethod shadows the module-level ``build_lr_scheduler``
    imported from ``detectron2.projects.deeplab`` (see the file's imports);
    because method bodies do not see class scope, the call below resolves to
    that imported function, not to this method — it is delegation, not
    recursion.
    """
    return build_lr_scheduler(cfg, optimizer)

@classmethod
def test(cls, cfg, model, evaluators=None):
    """
    Evaluate the given model. The given model is expected to already contain
    weights to evaluate.

    Args:
        cfg (CfgNode):
        model (nn.Module): model whose ``.device`` is forwarded to the test
            loader so the mapper can run augmentations on the same device.
        evaluators (list[DatasetEvaluator] or None): if None, will call
            :meth:`build_evaluator`. Otherwise, must have the same length as
            ``cfg.DATASETS.TEST``.

    Returns:
        dict: a dict of result metrics, keyed by dataset name; if exactly one
        test dataset is configured, that dataset's metrics dict is returned
        directly.
    """
    logger = logging.getLogger(get_logger_name())
    # Allow a single evaluator to be passed without wrapping it in a list.
    if isinstance(evaluators, DatasetEvaluator):
        evaluators = [evaluators]
    if evaluators is not None:
        assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(len(cfg.DATASETS.TEST), len(evaluators))

    results = OrderedDict()
    for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
        data_loader = cls.build_test_loader(cfg, dataset_name, model.device)
        # When evaluators are passed in as arguments,
        # implicitly assume that evaluators can be created before data_loader.
        if evaluators is not None:
            evaluator = evaluators[idx]
        else:
            try:
                evaluator = cls.build_evaluator(cfg, dataset_name)
            except NotImplementedError:
                # logger.warn is a deprecated alias; use logger.warning.
                logger.warning(
                    "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                    "or implement its `build_evaluator` method."
                )
                results[dataset_name] = {}
                continue
        results_i = inference_on_dataset(model, data_loader, evaluator)
        results[dataset_name] = results_i
        # Only the main process aggregates and logs metrics in distributed runs.
        if comm.is_main_process():
            assert isinstance(results_i, dict), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results_i
            )
            logger.info("Evaluation results for {} in csv format:".format(dataset_name))
            print_csv_format(results_i)

    # With a single test dataset, unwrap to that dataset's metrics dict.
    if len(results) == 1:
        results = next(iter(results.values()))
    return results

def validate(self):
results = self.test(self.cfg, self.model)
Loading

0 comments on commit 60cd4a0

Please sign in to comment.