From e5ca8ef8cdb6c1f77790df037712c21d75551ec7 Mon Sep 17 00:00:00 2001 From: Zhi Tian Date: Fri, 3 Jul 2020 21:17:46 +0930 Subject: [PATCH] add CondInst --- adet/config/defaults.py | 23 ++ adet/modeling/__init__.py | 1 + adet/modeling/backbone/__init__.py | 1 + adet/modeling/backbone/bifpn.py | 397 ++++++++++++++++++++ adet/modeling/condinst/__init__.py | 1 + adet/modeling/condinst/condinst.py | 209 +++++++++++ adet/modeling/condinst/dynamic_mask_head.py | 170 +++++++++ adet/modeling/condinst/mask_branch.py | 138 +++++++ adet/modeling/fcos/fcos.py | 4 + adet/utils/comm.py | 2 +- configs/CondInst/Base-CondInst.yaml | 26 ++ configs/CondInst/MS_R_101_1x.yaml | 6 + configs/CondInst/MS_R_101_3x.yaml | 9 + configs/CondInst/MS_R_101_3x_sem.yaml | 12 + configs/CondInst/MS_R_101_BiFPN_3x.yaml | 15 + configs/CondInst/MS_R_101_BiFPN_3x_sem.yaml | 18 + configs/CondInst/MS_R_50_1x.yaml | 6 + configs/CondInst/MS_R_50_3x.yaml | 9 + configs/CondInst/MS_R_50_3x_sem.yaml | 12 + configs/CondInst/MS_R_50_BiFPN_1x.yaml | 12 + configs/CondInst/MS_R_50_BiFPN_3x.yaml | 15 + configs/CondInst/MS_R_50_BiFPN_3x_sem.yaml | 18 + configs/CondInst/README.md | 83 ++++ 23 files changed, 1186 insertions(+), 1 deletion(-) create mode 100644 adet/modeling/backbone/bifpn.py create mode 100644 adet/modeling/condinst/__init__.py create mode 100644 adet/modeling/condinst/condinst.py create mode 100644 adet/modeling/condinst/dynamic_mask_head.py create mode 100644 adet/modeling/condinst/mask_branch.py create mode 100644 configs/CondInst/Base-CondInst.yaml create mode 100644 configs/CondInst/MS_R_101_1x.yaml create mode 100644 configs/CondInst/MS_R_101_3x.yaml create mode 100644 configs/CondInst/MS_R_101_3x_sem.yaml create mode 100644 configs/CondInst/MS_R_101_BiFPN_3x.yaml create mode 100644 configs/CondInst/MS_R_101_BiFPN_3x_sem.yaml create mode 100644 configs/CondInst/MS_R_50_1x.yaml create mode 100644 configs/CondInst/MS_R_50_3x.yaml create mode 100644 configs/CondInst/MS_R_50_3x_sem.yaml create mode 100644 configs/CondInst/MS_R_50_BiFPN_1x.yaml create mode 100644 configs/CondInst/MS_R_50_BiFPN_3x.yaml create mode 100644 configs/CondInst/MS_R_50_BiFPN_3x_sem.yaml create mode 100644 configs/CondInst/README.md diff --git a/adet/config/defaults.py b/adet/config/defaults.py index 9bcbc814c..fa30e4fa8 100644 --- a/adet/config/defaults.py +++ b/adet/config/defaults.py @@ -199,6 +199,29 @@ # Whether to compute loss on original mask (binary mask). 
_C.MODEL.MEInst.LOSS_ON_MASK = False +# ---------------------------------------------------------------------------- # +# CondInst Options +# ---------------------------------------------------------------------------- # +_C.MODEL.CONDINST = CN() + +# the downsampling ratio of the final instance masks to the input image +_C.MODEL.CONDINST.MASK_OUT_STRIDE = 4 +_C.MODEL.CONDINST.MAX_PROPOSALS = -1 + +_C.MODEL.CONDINST.MASK_HEAD = CN() +_C.MODEL.CONDINST.MASK_HEAD.CHANNELS = 8 +_C.MODEL.CONDINST.MASK_HEAD.NUM_LAYERS = 3 +_C.MODEL.CONDINST.MASK_HEAD.USE_FP16 = False +_C.MODEL.CONDINST.MASK_HEAD.DISABLE_REL_COORDS = False + +_C.MODEL.CONDINST.MASK_BRANCH = CN() +_C.MODEL.CONDINST.MASK_BRANCH.OUT_CHANNELS = 8 +_C.MODEL.CONDINST.MASK_BRANCH.IN_FEATURES = ["p3", "p4", "p5"] +_C.MODEL.CONDINST.MASK_BRANCH.CHANNELS = 128 +_C.MODEL.CONDINST.MASK_BRANCH.NORM = "BN" +_C.MODEL.CONDINST.MASK_BRANCH.NUM_CONVS = 4 +_C.MODEL.CONDINST.MASK_BRANCH.SEMANTIC_LOSS_ON = False + # ---------------------------------------------------------------------------- # # TOP Module Options # ---------------------------------------------------------------------------- # diff --git a/adet/modeling/__init__.py b/adet/modeling/__init__.py index b261d6839..c8224c86b 100644 --- a/adet/modeling/__init__.py +++ b/adet/modeling/__init__.py @@ -6,6 +6,7 @@ from .roi_heads.text_head import TextHead from .batext import BAText from .MEInst import MEInst +from .condinst import condinst _EXCLUDE = {"torch", "ShapeSpec"} __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] diff --git a/adet/modeling/backbone/__init__.py b/adet/modeling/backbone/__init__.py index 2edb0a660..93af74465 100644 --- a/adet/modeling/backbone/__init__.py +++ b/adet/modeling/backbone/__init__.py @@ -2,3 +2,4 @@ from .vovnet import build_vovnet_fpn_backbone, build_vovnet_backbone from .dla import build_fcos_dla_fpn_backbone from .resnet_lpf import build_resnet_lpf_backbone +from .bifpn import build_fcos_resnet_bifpn_backbone diff --git a/adet/modeling/backbone/bifpn.py b/adet/modeling/backbone/bifpn.py new file mode 100644 index 000000000..8cb3446d5 --- /dev/null +++ b/adet/modeling/backbone/bifpn.py @@ -0,0 +1,397 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from detectron2.layers import Conv2d, ShapeSpec, get_norm + +from detectron2.modeling.backbone import Backbone, build_resnet_backbone +from detectron2.modeling import BACKBONE_REGISTRY +from .mobilenet import build_mnv2_backbone + +__all__ = [] + + +def swish(x): + return x * x.sigmoid() + + +def split_name(name): + for i, c in enumerate(name): + if not c.isalpha(): + return name[:i], int(name[i:]) + raise ValueError() + + +class FeatureMapResampler(nn.Module): + def __init__(self, in_channels, out_channels, stride, norm=""): + super(FeatureMapResampler, self).__init__() + if in_channels != out_channels: + self.reduction = Conv2d( + in_channels, out_channels, kernel_size=1, + bias=(norm == ""), + norm=get_norm(norm, out_channels), + activation=None + ) + else: + self.reduction = None + + assert stride <= 2 + self.stride = stride + + def forward(self, x): + if self.reduction is not None: + x = self.reduction(x) + + if self.stride == 2: + x = F.max_pool2d( + x, kernel_size=self.stride + 1, + stride=self.stride, padding=1 + ) + elif self.stride == 1: + pass + else: + raise NotImplementedError() + return x + + +class BackboneWithTopLevels(Backbone): + def __init__(self, backbone, out_channels, num_top_levels, norm=""): + super(BackboneWithTopLevels, 
self).__init__() + self.backbone = backbone + backbone_output_shape = backbone.output_shape() + + self._out_feature_channels = {name: shape.channels for name, shape in backbone_output_shape.items()} + self._out_feature_strides = {name: shape.stride for name, shape in backbone_output_shape.items()} + self._out_features = list(self._out_feature_strides.keys()) + + last_feature_name = max(self._out_feature_strides.keys(), key=lambda x: split_name(x)[1]) + self.last_feature_name = last_feature_name + self.num_top_levels = num_top_levels + + last_channels = self._out_feature_channels[last_feature_name] + last_stride = self._out_feature_strides[last_feature_name] + + prefix, suffix = split_name(last_feature_name) + prev_channels = last_channels + for i in range(num_top_levels): + name = prefix + str(suffix + i + 1) + self.add_module(name, FeatureMapResampler( + prev_channels, out_channels, 2, norm + )) + prev_channels = out_channels + + self._out_feature_channels[name] = out_channels + self._out_feature_strides[name] = last_stride * 2 ** (i + 1) + self._out_features.append(name) + + def forward(self, x): + outputs = self.backbone(x) + last_features = outputs[self.last_feature_name] + prefix, suffix = split_name(self.last_feature_name) + + x = last_features + for i in range(self.num_top_levels): + name = prefix + str(suffix + i + 1) + x = self.__getattr__(name)(x) + outputs[name] = x + + return outputs + + +class SingleBiFPN(Backbone): + """ + This module implements Feature Pyramid Network. + It creates pyramid features built on top of some input feature maps. + """ + + def __init__( + self, in_channels_list, out_channels, norm="" + ): + """ + Args: + bottom_up (Backbone): module representing the bottom up subnetwork. + Must be a subclass of :class:`Backbone`. The multi-scale feature + maps generated by the bottom up network, and listed in `in_features`, + are used to generate FPN levels. + in_features (list[str]): names of the input feature maps coming + from the backbone to which FPN is attached. For example, if the + backbone produces ["res2", "res3", "res4"], any *contiguous* sublist + of these may be used; order must be from high to low resolution. + out_channels (int): number of channels in the output feature maps. + norm (str): the normalization to use. 
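+            in_channels_list (list[int]): number of channels of each input feature
+                map, ordered from the highest to the lowest resolution level.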
+ """ + super(SingleBiFPN, self).__init__() + + self.out_channels = out_channels + # build 5-levels bifpn + if len(in_channels_list) == 5: + self.nodes = [ + {'feat_level': 3, 'inputs_offsets': [3, 4]}, + {'feat_level': 2, 'inputs_offsets': [2, 5]}, + {'feat_level': 1, 'inputs_offsets': [1, 6]}, + {'feat_level': 0, 'inputs_offsets': [0, 7]}, + {'feat_level': 1, 'inputs_offsets': [1, 7, 8]}, + {'feat_level': 2, 'inputs_offsets': [2, 6, 9]}, + {'feat_level': 3, 'inputs_offsets': [3, 5, 10]}, + {'feat_level': 4, 'inputs_offsets': [4, 11]}, + ] + elif len(in_channels_list) == 3: + self.nodes = [ + {'feat_level': 1, 'inputs_offsets': [1, 2]}, + {'feat_level': 0, 'inputs_offsets': [0, 3]}, + {'feat_level': 1, 'inputs_offsets': [1, 3, 4]}, + {'feat_level': 2, 'inputs_offsets': [2, 5]}, + ] + else: + raise NotImplementedError + + node_info = [_ for _ in in_channels_list] + + num_output_connections = [0 for _ in in_channels_list] + for fnode in self.nodes: + feat_level = fnode["feat_level"] + inputs_offsets = fnode["inputs_offsets"] + inputs_offsets_str = "_".join(map(str, inputs_offsets)) + for input_offset in inputs_offsets: + num_output_connections[input_offset] += 1 + + in_channels = node_info[input_offset] + if in_channels != out_channels: + lateral_conv = Conv2d( + in_channels, + out_channels, + kernel_size=1, + norm=get_norm(norm, out_channels) + ) + self.add_module( + "lateral_{}_f{}".format(input_offset, feat_level), lateral_conv + ) + node_info.append(out_channels) + num_output_connections.append(0) + + # generate attention weights + name = "weights_f{}_{}".format(feat_level, inputs_offsets_str) + self.__setattr__(name, nn.Parameter( + torch.ones(len(inputs_offsets), dtype=torch.float32), + requires_grad=True + )) + + # generate convolutions after combination + name = "outputs_f{}_{}".format(feat_level, inputs_offsets_str) + self.add_module(name, Conv2d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + norm=get_norm(norm, out_channels), + bias=(norm == "") + )) + + def forward(self, feats): + """ + Args: + input (dict[str->Tensor]): mapping feature map name (e.g., "p5") to + feature map tensor for each feature level in high to low resolution order. + + Returns: + dict[str->Tensor]: + mapping from feature map name to FPN feature map tensor + in high to low resolution order. Returned feature names follow the FPN + paper convention: "p", where stage has stride = 2 ** stage e.g., + ["n2", "n3", ..., "n6"]. 
+ """ + feats = [_ for _ in feats] + num_levels = len(feats) + num_output_connections = [0 for _ in feats] + for fnode in self.nodes: + feat_level = fnode["feat_level"] + inputs_offsets = fnode["inputs_offsets"] + inputs_offsets_str = "_".join(map(str, inputs_offsets)) + input_nodes = [] + _, _, target_h, target_w = feats[feat_level].size() + for input_offset in inputs_offsets: + num_output_connections[input_offset] += 1 + input_node = feats[input_offset] + + # reduction + if input_node.size(1) != self.out_channels: + name = "lateral_{}_f{}".format(input_offset, feat_level) + input_node = self.__getattr__(name)(input_node) + + # maybe downsample + _, _, h, w = input_node.size() + if h > target_h and w > target_w: + height_stride_size = int((h - 1) // target_h + 1) + width_stride_size = int((w - 1) // target_w + 1) + assert height_stride_size == width_stride_size == 2 + input_node = F.max_pool2d( + input_node, kernel_size=(height_stride_size + 1, width_stride_size + 1), + stride=(height_stride_size, width_stride_size), padding=1 + ) + elif h <= target_h and w <= target_w: + if h < target_h or w < target_w: + input_node = F.interpolate( + input_node, + size=(target_h, target_w), + mode="nearest" + ) + else: + raise NotImplementedError() + input_nodes.append(input_node) + + # attention + name = "weights_f{}_{}".format(feat_level, inputs_offsets_str) + weights = F.relu(self.__getattr__(name)) + norm_weights = weights / (weights.sum() + 0.0001) + + new_node = torch.stack(input_nodes, dim=-1) + new_node = (norm_weights * new_node).sum(dim=-1) + new_node = swish(new_node) + + name = "outputs_f{}_{}".format(feat_level, inputs_offsets_str) + feats.append(self.__getattr__(name)(new_node)) + + num_output_connections.append(0) + + output_feats = [] + for idx in range(num_levels): + for i, fnode in enumerate(reversed(self.nodes)): + if fnode['feat_level'] == idx: + output_feats.append(feats[-1 - i]) + break + else: + raise ValueError() + return output_feats + + +class BiFPN(Backbone): + """ + This module implements Feature Pyramid Network. + It creates pyramid features built on top of some input feature maps. + """ + + def __init__( + self, bottom_up, in_features, out_channels, num_top_levels, num_repeats, norm="" + ): + """ + Args: + bottom_up (Backbone): module representing the bottom up subnetwork. + Must be a subclass of :class:`Backbone`. The multi-scale feature + maps generated by the bottom up network, and listed in `in_features`, + are used to generate FPN levels. + in_features (list[str]): names of the input feature maps coming + from the backbone to which FPN is attached. For example, if the + backbone produces ["res2", "res3", "res4"], any *contiguous* sublist + of these may be used; order must be from high to low resolution. + out_channels (int): number of channels in the output feature maps. + num_top_levels (int): the number of the top levels (p6 or p7). + num_repeats (int): the number of repeats of BiFPN. + norm (str): the normalization to use. 
+ """ + super(BiFPN, self).__init__() + assert isinstance(bottom_up, Backbone) + + # add extra feature levels (i.e., 6 and 7) + self.bottom_up = BackboneWithTopLevels( + bottom_up, out_channels, + num_top_levels, norm + ) + bottom_up_output_shapes = self.bottom_up.output_shape() + + in_features = sorted(in_features, key=lambda x: split_name(x)[1]) + self._size_divisibility = bottom_up_output_shapes[in_features[-1]].stride + self.out_channels = out_channels + self.min_level = split_name(in_features[0])[1] + + # add the names for top blocks + prefix, last_suffix = split_name(in_features[-1]) + for i in range(num_top_levels): + in_features.append(prefix + str(last_suffix + i + 1)) + self.in_features = in_features + + # generate output features + self._out_features = ["p{}".format(split_name(name)[1]) for name in in_features] + self._out_feature_strides = { + out_name: bottom_up_output_shapes[in_name].stride + for out_name, in_name in zip(self._out_features, in_features) + } + self._out_feature_channels = {k: out_channels for k in self._out_features} + + # build bifpn + self.repeated_bifpn = nn.ModuleList() + for i in range(num_repeats): + if i == 0: + in_channels_list = [ + bottom_up_output_shapes[name].channels for name in in_features + ] + else: + in_channels_list = [ + self._out_feature_channels[name] for name in self._out_features + ] + self.repeated_bifpn.append(SingleBiFPN( + in_channels_list, out_channels, norm + )) + + @property + def size_divisibility(self): + return self._size_divisibility + + def forward(self, x): + """ + Args: + input (dict[str->Tensor]): mapping feature map name (e.g., "p5") to + feature map tensor for each feature level in high to low resolution order. + + Returns: + dict[str->Tensor]: + mapping from feature map name to FPN feature map tensor + in high to low resolution order. Returned feature names follow the FPN + paper convention: "p", where stage has stride = 2 ** stage e.g., + ["n2", "n3", ..., "n6"]. + """ + bottom_up_features = self.bottom_up(x) + feats = [bottom_up_features[f] for f in self.in_features] + + for bifpn in self.repeated_bifpn: + feats = bifpn(feats) + + return dict(zip(self._out_features, feats)) + + +def _assert_strides_are_log2_contiguous(strides): + """ + Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2". + """ + for i, stride in enumerate(strides[1:], 1): + assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format( + stride, strides[i - 1] + ) + + +@BACKBONE_REGISTRY.register() +def build_fcos_resnet_bifpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 
+ """ + if cfg.MODEL.MOBILENET: + bottom_up = build_mnv2_backbone(cfg, input_shape) + else: + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.BiFPN.IN_FEATURES + out_channels = cfg.MODEL.BiFPN.OUT_CHANNELS + num_repeats = cfg.MODEL.BiFPN.NUM_REPEATS + top_levels = cfg.MODEL.FCOS.TOP_LEVELS + + backbone = BiFPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + num_top_levels=top_levels, + num_repeats=num_repeats, + norm=cfg.MODEL.BiFPN.NORM + ) + return backbone diff --git a/adet/modeling/condinst/__init__.py b/adet/modeling/condinst/__init__.py new file mode 100644 index 000000000..395e19fcf --- /dev/null +++ b/adet/modeling/condinst/__init__.py @@ -0,0 +1 @@ +from .condinst import CondInst diff --git a/adet/modeling/condinst/condinst.py b/adet/modeling/condinst/condinst.py new file mode 100644 index 000000000..4eba13ae0 --- /dev/null +++ b/adet/modeling/condinst/condinst.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- +import logging + +import torch +from torch import nn +import torch.nn.functional as F + +from detectron2.structures import ImageList +from detectron2.modeling.proposal_generator import build_proposal_generator +from detectron2.modeling.backbone import build_backbone +from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY +from detectron2.structures.instances import Instances +from detectron2.structures.masks import polygons_to_bitmask + +from .dynamic_mask_head import build_dynamic_mask_head +from .mask_branch import build_mask_branch + +from adet.utils.comm import aligned_bilinear + +__all__ = ["CondInst"] + + +logger = logging.getLogger(__name__) + + +@META_ARCH_REGISTRY.register() +class CondInst(nn.Module): + """ + Main class for CondInst architectures (see https://arxiv.org/abs/2003.05664). 
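+
+    An FCOS-style proposal generator predicts instances and, via an extra "controller"
+    conv layer, the filter parameters of a per-instance dynamic mask head; that mask
+    head is then applied to the shared features produced by the mask branch to obtain
+    each instance's mask.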
+ """ + + def __init__(self, cfg): + super().__init__() + self.device = torch.device(cfg.MODEL.DEVICE) + + self.backbone = build_backbone(cfg) + self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape()) + self.mask_head = build_dynamic_mask_head(cfg) + self.mask_branch = build_mask_branch(cfg, self.backbone.output_shape()) + self.mask_out_stride = cfg.MODEL.CONDINST.MASK_OUT_STRIDE + self.max_proposals = cfg.MODEL.CONDINST.MAX_PROPOSALS + + # build top module + in_channels = self.proposal_generator.in_channels_to_top_module + + self.controller = nn.Conv2d( + in_channels, self.mask_head.num_gen_params, + kernel_size=3, stride=1, padding=1 + ) + torch.nn.init.normal_(self.controller.weight, std=0.01) + torch.nn.init.constant_(self.controller.bias, 0) + + pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) + pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) + self.normalizer = lambda x: (x - pixel_mean) / pixel_std + self.to(self.device) + + def forward(self, batched_inputs): + images = [x["image"].to(self.device) for x in batched_inputs] + images = [self.normalizer(x) for x in images] + images = ImageList.from_tensors(images, self.backbone.size_divisibility) + features = self.backbone(images.tensor) + + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + self.add_bitmasks(gt_instances, images.tensor.size(-2), images.tensor.size(-1)) + else: + gt_instances = None + + mask_feats, sem_losses = self.mask_branch(features, gt_instances) + + proposals, proposal_losses = self.proposal_generator( + images, features, gt_instances, self.controller + ) + + if self.training: + loss_mask = self._forward_mask_heads_train(proposals, mask_feats, gt_instances) + + losses = {} + losses.update(sem_losses) + losses.update(proposal_losses) + losses.update({"loss_mask": loss_mask}) + return losses + else: + pred_instances_w_masks = self._forward_mask_heads_test(proposals, mask_feats) + + padded_im_h, padded_im_w = images.tensor.size()[-2:] + processed_results = [] + for im_id, (input_per_image, image_size) in enumerate(zip(batched_inputs, images.image_sizes)): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + + instances_per_im = pred_instances_w_masks[pred_instances_w_masks.im_inds == im_id] + instances_per_im = self.postprocess( + instances_per_im, height, width, + padded_im_h, padded_im_w + ) + + processed_results.append({ + "instances": instances_per_im + }) + + return processed_results + + def _forward_mask_heads_train(self, proposals, mask_feats, gt_instances): + # prepare the inputs for mask heads + pred_instances = proposals["instances"] + + if 0 <= self.max_proposals < len(pred_instances): + inds = torch.randperm(len(pred_instances), device=mask_feats.device).long() + logger.info("clipping proposals from {} to {}".format( + len(pred_instances), self.max_proposals + )) + pred_instances = pred_instances[inds[:self.max_proposals]] + + pred_instances.mask_head_params = pred_instances.top_feats + + loss_mask = self.mask_head( + mask_feats, self.mask_branch.out_stride, + pred_instances, gt_instances + ) + + return loss_mask + + def _forward_mask_heads_test(self, proposals, mask_feats): + # prepare the inputs for mask heads + for im_id, per_im in enumerate(proposals): + per_im.im_inds = per_im.locations.new_ones(len(per_im), dtype=torch.long) * im_id + pred_instances = Instances.cat(proposals) + 
pred_instances.mask_head_params = pred_instances.top_feat + + pred_instances_w_masks = self.mask_head( + mask_feats, self.mask_branch.out_stride, pred_instances + ) + + return pred_instances_w_masks + + def add_bitmasks(self, instances, im_h, im_w): + for per_im_gt_inst in instances: + if not per_im_gt_inst.has("gt_masks"): + continue + polygons = per_im_gt_inst.get("gt_masks").polygons + per_im_bitmasks = [] + per_im_bitmasks_full = [] + for per_polygons in polygons: + bitmask = polygons_to_bitmask(per_polygons, im_h, im_w) + bitmask = torch.from_numpy(bitmask).to(self.device).float() + start = int(self.mask_out_stride // 2) + bitmask_full = bitmask.clone() + bitmask = bitmask[start::self.mask_out_stride, start::self.mask_out_stride] + + assert bitmask.size(0) * self.mask_out_stride == im_h + assert bitmask.size(1) * self.mask_out_stride == im_w + + per_im_bitmasks.append(bitmask) + per_im_bitmasks_full.append(bitmask_full) + + per_im_gt_inst.gt_bitmasks = torch.stack(per_im_bitmasks, dim=0) + per_im_gt_inst.gt_bitmasks_full = torch.stack(per_im_bitmasks_full, dim=0) + + def postprocess(self, results, output_height, output_width, padded_im_h, padded_im_w, mask_threshold=0.5): + """ + Resize the output instances. + The input images are often resized when entering an object detector. + As a result, we often need the outputs of the detector in a different + resolution from its inputs. + This function will resize the raw outputs of an R-CNN detector + to produce outputs according to the desired output resolution. + Args: + results (Instances): the raw outputs from the detector. + `results.image_size` contains the input image resolution the detector sees. + This object might be modified in-place. + output_height, output_width: the desired output resolution. 
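+            padded_im_h, padded_im_w: the height/width of the padded image batch fed to
+                the network; used to crop the upsampled global masks back to the resized
+                image region before rescaling them to the output resolution.
+            mask_threshold: the threshold used to binarize the predicted masks.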
+ Returns: + Instances: the resized output from the model, based on the output resolution + """ + scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0]) + resized_im_h, resized_im_w = results.image_size + results = Instances((output_height, output_width), **results.get_fields()) + + if results.has("pred_boxes"): + output_boxes = results.pred_boxes + elif results.has("proposal_boxes"): + output_boxes = results.proposal_boxes + + output_boxes.scale(scale_x, scale_y) + output_boxes.clip(results.image_size) + + results = results[output_boxes.nonempty()] + + if results.has("pred_global_masks"): + mask_h, mask_w = results.pred_global_masks.size()[-2:] + factor_h = padded_im_h // mask_h + factor_w = padded_im_w // mask_w + assert factor_h == factor_w + factor = factor_h + pred_global_masks = aligned_bilinear( + results.pred_global_masks, factor + ) + pred_global_masks = pred_global_masks[:, :, :resized_im_h, :resized_im_w] + pred_global_masks = F.interpolate( + pred_global_masks, + size=(output_height, output_width), + mode="bilinear", align_corners=False + ) + pred_global_masks = pred_global_masks[:, 0, :, :] + results.pred_masks = (pred_global_masks > mask_threshold).float() + + return results diff --git a/adet/modeling/condinst/dynamic_mask_head.py b/adet/modeling/condinst/dynamic_mask_head.py new file mode 100644 index 000000000..64e7a135a --- /dev/null +++ b/adet/modeling/condinst/dynamic_mask_head.py @@ -0,0 +1,170 @@ +import torch +from torch.nn import functional as F +from torch import nn + +from adet.utils.comm import compute_locations, aligned_bilinear + + +def dice_coefficient(x, target): + eps = 1e-5 + n_inst = x.size(0) + x = x.reshape(n_inst, -1) + target = target.reshape(n_inst, -1) + intersection = (x * target).sum(dim=1) + union = (x ** 2.0).sum(dim=1) + (target ** 2.0).sum(dim=1) + eps + loss = 1. 
- (2 * intersection / union) + return loss + + +def parse_dynamic_params(params, channels, weight_nums, bias_nums): + assert params.dim() == 2 + assert len(weight_nums) == len(bias_nums) + assert params.size(1) == sum(weight_nums) + sum(bias_nums) + + num_insts = params.size(0) + num_layers = len(weight_nums) + + params_splits = list(torch.split_with_sizes( + params, weight_nums + bias_nums, dim=1 + )) + + weight_splits = params_splits[:num_layers] + bias_splits = params_splits[num_layers:] + + for l in range(num_layers): + if l < num_layers - 1: + # out_channels x in_channels x 1 x 1 + weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1) + bias_splits[l] = bias_splits[l].reshape(num_insts * channels) + else: + # out_channels x in_channels x 1 x 1 + weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1) + bias_splits[l] = bias_splits[l].reshape(num_insts) + + return weight_splits, bias_splits + + +def build_dynamic_mask_head(cfg): + return DynamicMaskHead(cfg) + + +class DynamicMaskHead(nn.Module): + def __init__(self, cfg): + super(DynamicMaskHead, self).__init__() + self.num_layers = cfg.MODEL.CONDINST.MASK_HEAD.NUM_LAYERS + self.channels = cfg.MODEL.CONDINST.MASK_HEAD.CHANNELS + self.in_channels = cfg.MODEL.CONDINST.MASK_BRANCH.OUT_CHANNELS + self.mask_out_stride = cfg.MODEL.CONDINST.MASK_OUT_STRIDE + self.disable_rel_coords = cfg.MODEL.CONDINST.MASK_HEAD.DISABLE_REL_COORDS + + soi = cfg.MODEL.FCOS.SIZES_OF_INTEREST + self.register_buffer("sizes_of_interest", torch.tensor(soi + [soi[-1] * 2])) + + weight_nums, bias_nums = [], [] + for l in range(self.num_layers): + if l == 0: + if not self.disable_rel_coords: + weight_nums.append((self.in_channels + 2) * self.channels) + else: + weight_nums.append(self.in_channels * self.channels) + bias_nums.append(self.channels) + elif l == self.num_layers - 1: + weight_nums.append(self.channels * 1) + bias_nums.append(1) + else: + weight_nums.append(self.channels * self.channels) + bias_nums.append(self.channels) + + self.weight_nums = weight_nums + self.bias_nums = bias_nums + self.num_gen_params = sum(weight_nums) + sum(bias_nums) + + def mask_heads_forward(self, features, weights, biases, num_insts): + ''' + :param features + :param weights: [w0, w1, ...] + :param bias: [b0, b1, ...] 
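+        :param num_insts: number of instances; the convolutions below are applied as a
+            grouped convolution with one group per instance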
+ :return: + ''' + assert features.dim() == 4 + n_layers = len(weights) + x = features + for i, (w, b) in enumerate(zip(weights, biases)): + x = F.conv2d( + x, w, bias=b, + stride=1, padding=0, + groups=num_insts + ) + if i < n_layers - 1: + x = F.relu(x) + return x + + def mask_heads_forward_with_coords( + self, mask_feats, mask_feat_stride, instances + ): + locations = compute_locations( + mask_feats.size(2), mask_feats.size(3), + stride=mask_feat_stride, device=mask_feats.device + ) + n_inst = len(instances) + + im_inds = instances.im_inds + mask_head_params = instances.mask_head_params + + N, _, H, W = mask_feats.size() + + if not self.disable_rel_coords: + instance_locations = instances.locations + relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2) + relative_coords = relative_coords.permute(0, 2, 1).float() + soi = self.sizes_of_interest.float()[instances.fpn_levels] + relative_coords = relative_coords / soi.reshape(-1, 1, 1) + relative_coords = relative_coords.to(dtype=mask_feats.dtype) + + mask_head_inputs = torch.cat([ + relative_coords, mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W) + ], dim=1) + else: + mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W) + + mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W) + + weights, biases = parse_dynamic_params( + mask_head_params, self.channels, + self.weight_nums, self.bias_nums + ) + + mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases, n_inst) + + mask_logits = mask_logits.reshape(-1, 1, H, W) + + assert mask_feat_stride >= self.mask_out_stride + assert mask_feat_stride % self.mask_out_stride == 0 + mask_logits = aligned_bilinear(mask_logits, int(mask_feat_stride / self.mask_out_stride)) + + return mask_logits.sigmoid() + + def __call__(self, mask_feats, mask_feat_stride, pred_instances, gt_instances=None): + if self.training: + gt_inds = pred_instances.gt_inds + gt_bitmasks = torch.cat([per_im.gt_bitmasks for per_im in gt_instances]) + gt_bitmasks = gt_bitmasks[gt_inds].unsqueeze(dim=1).to(dtype=mask_feats.dtype) + + if len(pred_instances) == 0: + loss_mask = mask_feats.sum() * 0 + pred_instances.mask_head_params.sum() * 0 + else: + mask_scores = self.mask_heads_forward_with_coords( + mask_feats, mask_feat_stride, pred_instances + ) + mask_losses = dice_coefficient(mask_scores, gt_bitmasks) + loss_mask = mask_losses.mean() + + return loss_mask.float() + else: + if len(pred_instances) > 0: + mask_scores = self.mask_heads_forward_with_coords( + mask_feats, mask_feat_stride, pred_instances + ) + pred_instances.pred_global_masks = mask_scores.float() + + return pred_instances diff --git a/adet/modeling/condinst/mask_branch.py b/adet/modeling/condinst/mask_branch.py new file mode 100644 index 000000000..bb15fcb43 --- /dev/null +++ b/adet/modeling/condinst/mask_branch.py @@ -0,0 +1,138 @@ +from typing import Dict +import math + +import torch +from torch import nn + +from fvcore.nn import sigmoid_focal_loss_jit +from detectron2.layers import ShapeSpec + +from adet.layers import conv_with_kaiming_uniform +from adet.utils.comm import aligned_bilinear + + +INF = 100000000 + + +def build_mask_branch(cfg, input_shape): + return MaskBranch(cfg, input_shape) + + +class MaskBranch(nn.Module): + def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]): + super().__init__() + self.in_features = cfg.MODEL.CONDINST.MASK_BRANCH.IN_FEATURES + self.sem_loss_on = cfg.MODEL.CONDINST.MASK_BRANCH.SEMANTIC_LOSS_ON + self.num_outputs = 
cfg.MODEL.CONDINST.MASK_BRANCH.OUT_CHANNELS + norm = cfg.MODEL.CONDINST.MASK_BRANCH.NORM + num_convs = cfg.MODEL.CONDINST.MASK_BRANCH.NUM_CONVS + channels = cfg.MODEL.CONDINST.MASK_BRANCH.CHANNELS + self.out_stride = input_shape[self.in_features[0]].stride + + feature_channels = {k: v.channels for k, v in input_shape.items()} + + conv_block = conv_with_kaiming_uniform(norm, activation=True) + + self.refine = nn.ModuleList() + for in_feature in self.in_features: + self.refine.append(conv_block( + feature_channels[in_feature], + channels, 3, 1 + )) + + tower = [] + for i in range(num_convs): + tower.append(conv_block( + channels, channels, 3, 1 + )) + tower.append(nn.Conv2d( + channels, max(self.num_outputs, 1), 1 + )) + self.add_module('tower', nn.Sequential(*tower)) + + if self.sem_loss_on: + num_classes = cfg.MODEL.FCOS.NUM_CLASSES + self.focal_loss_alpha = cfg.MODEL.FCOS.LOSS_ALPHA + self.focal_loss_gamma = cfg.MODEL.FCOS.LOSS_GAMMA + + in_channels = feature_channels[self.in_features[0]] + self.seg_head = nn.Sequential( + conv_block(in_channels, channels, kernel_size=3, stride=1), + conv_block(channels, channels, kernel_size=3, stride=1) + ) + + self.logits = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1) + + prior_prob = cfg.MODEL.FCOS.PRIOR_PROB + bias_value = -math.log((1 - prior_prob) / prior_prob) + torch.nn.init.constant_(self.logits.bias, bias_value) + + def forward(self, features, gt_instances=None): + for i, f in enumerate(self.in_features): + if i == 0: + x = self.refine[i](features[f]) + else: + x_p = self.refine[i](features[f]) + + target_h, target_w = x.size()[2:] + h, w = x_p.size()[2:] + assert target_h % h == 0 + assert target_w % w == 0 + factor_h, factor_w = target_h // h, target_w // w + assert factor_h == factor_w + x_p = aligned_bilinear(x_p, factor_h) + x = x + x_p + + mask_feats = self.tower(x) + + if self.num_outputs == 0: + mask_feats = mask_feats[:, :self.num_outputs] + + losses = {} + # auxiliary thing semantic loss + if self.training and self.sem_loss_on: + logits_pred = self.logits(self.seg_head( + features[self.in_features[0]] + )) + + # compute semantic targets + semantic_targets = [] + for per_im_gt in gt_instances: + h, w = per_im_gt.gt_bitmasks_full.size()[-2:] + areas = per_im_gt.gt_bitmasks_full.sum(dim=-1).sum(dim=-1) + areas = areas[:, None, None].repeat(1, h, w) + areas[per_im_gt.gt_bitmasks_full == 0] = INF + areas = areas.permute(1, 2, 0).reshape(h * w, -1) + min_areas, inds = areas.min(dim=1) + per_im_sematic_targets = per_im_gt.gt_classes[inds] + 1 + per_im_sematic_targets[min_areas == INF] = 0 + per_im_sematic_targets = per_im_sematic_targets.reshape(h, w) + semantic_targets.append(per_im_sematic_targets) + + semantic_targets = torch.stack(semantic_targets, dim=0) + + # resize target to reduce memory + semantic_targets = semantic_targets[ + :, None, self.out_stride // 2::self.out_stride, + self.out_stride // 2::self.out_stride + ] + + # prepare one-hot targets + num_classes = logits_pred.size(1) + class_range = torch.arange( + num_classes, dtype=logits_pred.dtype, + device=logits_pred.device + )[:, None, None] + class_range = class_range + 1 + one_hot = (semantic_targets == class_range).float() + num_pos = (one_hot > 0).sum().float().clamp(min=1.0) + + loss_sem = sigmoid_focal_loss_jit( + logits_pred, one_hot, + alpha=self.focal_loss_alpha, + gamma=self.focal_loss_gamma, + reduction="sum", + ) / num_pos + losses['loss_sem'] = loss_sem + + return mask_feats, losses diff --git a/adet/modeling/fcos/fcos.py b/adet/modeling/fcos/fcos.py 
index 73bc3d257..4591daae7 100644 --- a/adet/modeling/fcos/fcos.py +++ b/adet/modeling/fcos/fcos.py @@ -51,6 +51,8 @@ def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]): self.yield_proposal = cfg.MODEL.FCOS.YIELD_PROPOSAL self.fcos_head = FCOSHead(cfg, [input_shape[f] for f in self.in_features]) + self.in_channels_to_top_module = self.fcos_head.in_channels_to_top_module + self.fcos_outputs = FCOSOutputs(cfg) def forward_head(self, features, top_module=None): @@ -140,6 +142,8 @@ def __init__(self, cfg, input_shape: List[ShapeSpec]): assert len(set(in_channels)) == 1, "Each level must have the same channel!" in_channels = in_channels[0] + self.in_channels_to_top_module = in_channels + for head in head_configs: tower = [] num_convs, use_deformable = head_configs[head] diff --git a/adet/utils/comm.py b/adet/utils/comm.py index 802bf76da..1e9202da2 100644 --- a/adet/utils/comm.py +++ b/adet/utils/comm.py @@ -52,4 +52,4 @@ def compute_locations(h, w, stride, device): shift_x = shift_x.reshape(-1) shift_y = shift_y.reshape(-1) locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 - return locations \ No newline at end of file + return locations diff --git a/configs/CondInst/Base-CondInst.yaml b/configs/CondInst/Base-CondInst.yaml new file mode 100644 index 000000000..264ba6483 --- /dev/null +++ b/configs/CondInst/Base-CondInst.yaml @@ -0,0 +1,26 @@ +MODEL: + META_ARCHITECTURE: "CondInst" + MASK_ON: True + BACKBONE: + NAME: "build_fcos_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res3", "res4", "res5"] + PROPOSAL_GENERATOR: + NAME: "FCOS" + FCOS: + THRESH_WITH_CTR: True + USE_SCALE: True + CONDINST: + MAX_PROPOSALS: 500 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.01 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) diff --git a/configs/CondInst/MS_R_101_1x.yaml b/configs/CondInst/MS_R_101_1x.yaml new file mode 100644 index 000000000..e8b730471 --- /dev/null +++ b/configs/CondInst/MS_R_101_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 +OUTPUT_DIR: "output/condinst_MS_R_101_1x" diff --git a/configs/CondInst/MS_R_101_3x.yaml b/configs/CondInst/MS_R_101_3x.yaml new file mode 100644 index 000000000..d87efba3d --- /dev/null +++ b/configs/CondInst/MS_R_101_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/condinst_MS_R_101_3x" diff --git a/configs/CondInst/MS_R_101_3x_sem.yaml b/configs/CondInst/MS_R_101_3x_sem.yaml new file mode 100644 index 000000000..62cb2ad0a --- /dev/null +++ b/configs/CondInst/MS_R_101_3x_sem.yaml @@ -0,0 +1,12 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + RESNETS: + DEPTH: 101 + CONDINST: + MASK_BRANCH: + SEMANTIC_LOSS_ON: True +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/condinst_MS_R_101_3x_sem" diff --git a/configs/CondInst/MS_R_101_BiFPN_3x.yaml b/configs/CondInst/MS_R_101_BiFPN_3x.yaml new file mode 100644 index 000000000..dfc55100f --- /dev/null +++ b/configs/CondInst/MS_R_101_BiFPN_3x.yaml @@ -0,0 +1,15 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + BACKBONE: + NAME: 
"build_fcos_resnet_bifpn_backbone" + RESNETS: + DEPTH: 101 + BiFPN: + IN_FEATURES: ["res3", "res4", "res5"] + OUT_CHANNELS: 160 + NORM: "SyncBN" +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/condinst_MS_R_101_3x_bifpn" diff --git a/configs/CondInst/MS_R_101_BiFPN_3x_sem.yaml b/configs/CondInst/MS_R_101_BiFPN_3x_sem.yaml new file mode 100644 index 000000000..4d19e33ad --- /dev/null +++ b/configs/CondInst/MS_R_101_BiFPN_3x_sem.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + BACKBONE: + NAME: "build_fcos_resnet_bifpn_backbone" + RESNETS: + DEPTH: 101 + BiFPN: + IN_FEATURES: ["res3", "res4", "res5"] + OUT_CHANNELS: 160 + NORM: "SyncBN" + CONDINST: + MASK_BRANCH: + SEMANTIC_LOSS_ON: True +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/condinst_MS_R_101_3x_bifpn_sem" diff --git a/configs/CondInst/MS_R_50_1x.yaml b/configs/CondInst/MS_R_50_1x.yaml new file mode 100644 index 000000000..5271f7f0d --- /dev/null +++ b/configs/CondInst/MS_R_50_1x.yaml @@ -0,0 +1,6 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +OUTPUT_DIR: "output/condinst_MS_R_50_1x" diff --git a/configs/CondInst/MS_R_50_3x.yaml b/configs/CondInst/MS_R_50_3x.yaml new file mode 100644 index 000000000..fd115a6cc --- /dev/null +++ b/configs/CondInst/MS_R_50_3x.yaml @@ -0,0 +1,9 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/condinst_MS_R_50_3x" diff --git a/configs/CondInst/MS_R_50_3x_sem.yaml b/configs/CondInst/MS_R_50_3x_sem.yaml new file mode 100644 index 000000000..0a8f9fb1f --- /dev/null +++ b/configs/CondInst/MS_R_50_3x_sem.yaml @@ -0,0 +1,12 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + RESNETS: + DEPTH: 50 + CONDINST: + MASK_BRANCH: + SEMANTIC_LOSS_ON: True +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/condinst_MS_R_50_3x_sem" diff --git a/configs/CondInst/MS_R_50_BiFPN_1x.yaml b/configs/CondInst/MS_R_50_BiFPN_1x.yaml new file mode 100644 index 000000000..ffb189030 --- /dev/null +++ b/configs/CondInst/MS_R_50_BiFPN_1x.yaml @@ -0,0 +1,12 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + BACKBONE: + NAME: "build_fcos_resnet_bifpn_backbone" + RESNETS: + DEPTH: 50 + BiFPN: + IN_FEATURES: ["res3", "res4", "res5"] + OUT_CHANNELS: 160 + NORM: "SyncBN" +OUTPUT_DIR: "output/condinst_MS_R_50_1x_bifpn" diff --git a/configs/CondInst/MS_R_50_BiFPN_3x.yaml b/configs/CondInst/MS_R_50_BiFPN_3x.yaml new file mode 100644 index 000000000..a8c4f335d --- /dev/null +++ b/configs/CondInst/MS_R_50_BiFPN_3x.yaml @@ -0,0 +1,15 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + BACKBONE: + NAME: "build_fcos_resnet_bifpn_backbone" + RESNETS: + DEPTH: 50 + BiFPN: + IN_FEATURES: ["res3", "res4", "res5"] + OUT_CHANNELS: 160 + NORM: "SyncBN" +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/condinst_MS_R_50_3x_bifpn" diff --git a/configs/CondInst/MS_R_50_BiFPN_3x_sem.yaml b/configs/CondInst/MS_R_50_BiFPN_3x_sem.yaml new file mode 100644 index 000000000..2fec59dec --- /dev/null +++ b/configs/CondInst/MS_R_50_BiFPN_3x_sem.yaml @@ -0,0 +1,18 @@ +_BASE_: "Base-CondInst.yaml" +MODEL: + WEIGHTS: 
"detectron2://ImageNetPretrained/MSRA/R-50.pkl" + BACKBONE: + NAME: "build_fcos_resnet_bifpn_backbone" + RESNETS: + DEPTH: 50 + BiFPN: + IN_FEATURES: ["res3", "res4", "res5"] + OUT_CHANNELS: 160 + NORM: "SyncBN" + CONDINST: + MASK_BRANCH: + SEMANTIC_LOSS_ON: True +SOLVER: + STEPS: (210000, 250000) + MAX_ITER: 270000 +OUTPUT_DIR: "output/condinst_MS_R_50_3x_bifpn_sem" diff --git a/configs/CondInst/README.md b/configs/CondInst/README.md new file mode 100644 index 000000000..0bbf64e8e --- /dev/null +++ b/configs/CondInst/README.md @@ -0,0 +1,83 @@ +# Conditional Convolutions for Instance Segmentation (Oral) + + Conditional Convolutions for Instance Segmentation; + Zhi Tian, Chunhua Shen and Hao Chen; + In: Proc. European Conference on Computer Vision (ECCV), 2020. + arXiv preprint arXiv:2003.05664 + +[[`Paper`](https://arxiv.org/abs/2003.05664)] [[`BibTeX`](#citing-condinst)] + +# Installation & Quick Start +First, follow the [default instruction](../../README.md#Installation) to install the project, and +follow [datasets/README.md](https://github.com/facebookresearch/detectron2/blob/master/datasets/README.md) +set up the datasets (e.g., MS-COCO). + +For demo, run the following command lines: +``` +wget https://cloudstor.aarnet.edu.au/plus/s/M8nNxSR5iNP4qyO/download -O CondInst_MS_R_101_3x_sem.pth +python demo/demo.py \ + --config-file configs/CondInst/MS_R_101_3x_sem.yaml \ + --input input1.jpg input2.jpg \ + --opts MODEL.WEIGHTS CondInst_MS_R_101_3x_sem.pth +``` + +For training on COCO, run: +``` +OMP_NUM_THREADS=1 python tools/train_net.py \ + --config-file configs/CondInst/MS_R_50_1x.yaml \ + --num-gpus 8 \ + OUTPUT_DIR training_dir/CondInst_MS_R_50_1x +``` + +For evaluation on COCO, run: +``` +OMP_NUM_THREADS=1 python tools/train_net.py \ + --config-file configs/CondInst/MS_R_50_1x.yaml \ + --eval-only \ + --num-gpus 8 \ + OUTPUT_DIR training_dir/CondInst_MS_R_50_1x \ + MODEL.WEIGHTS training_dir/CondInst_MS_R_50_1x/model_final.pth +``` + + +## Models +### COCO Instance Segmentation Baselines with [CondInst](https://arxiv.org/abs/2003.05664) + +Name | inf. time | box AP | mask AP | download +--- |:---:|:---:|:---:|:---: +[CondInst_MS_R_50_1x](MS_R_50_1x.yaml) | - | 39.7 | 35.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/Trx1r4tLJja7sLT/download) +[CondInst_MS_R_50_3x](MS_R_50_3x.yaml) | - | 41.9 | 37.5 | [model](https://cloudstor.aarnet.edu.au/plus/s/T3OGVBiaSVLvo5E/download) +[CondInst_MS_R_101_3x](MS_R_101_3x.yaml) | - | 43.3 | 38.6 | [model](https://cloudstor.aarnet.edu.au/plus/s/vWLiYm8OnrTSUD2/download) + +With semantic segmentation loss (set `MODEL.CONDINST.MASK_BRANCH.SEMANTIC_LOSS_ON = True` to enable it): + +Name | inf. time | box AP | mask AP | mask AP (test-dev) | download +--- |:---:|:---:|:---:|:---:|:---: +[CondInst_MS_R_50_3x_sem](MS_R_50_3x_sem.yaml) | - | 42.6 | 38.2 | 38.7 | [model](https://cloudstor.aarnet.edu.au/plus/s/75Ag8VvC6WedVNh/download) +[CondInst_MS_R_101_3x_sem](MS_R_101_3x_sem.yaml) | - | 44.6 | 39.8 | 40.1 | [model](https://cloudstor.aarnet.edu.au/plus/s/M8nNxSR5iNP4qyO/download) + +With BiFPN: + +Name | inf. 
time | box AP | mask AP | download
+--- |:---:|:---:|:---:|:---:
+[CondInst_MS_R_50_BiFPN_1x](MS_R_50_BiFPN_1x.yaml) | - | 42.5 | 37.3 | [model](https://cloudstor.aarnet.edu.au/plus/s/RyCG82WhTop99j2/download)
+[CondInst_MS_R_50_BiFPN_3x](MS_R_50_BiFPN_3x.yaml) | - | 44.3 | 38.9 | [model](https://cloudstor.aarnet.edu.au/plus/s/W9ZCcxJF0P5NhJQ/download)
+[CondInst_MS_R_50_BiFPN_3x_sem](MS_R_50_BiFPN_3x_sem.yaml) | - | 44.7 | 39.4 | [model](https://cloudstor.aarnet.edu.au/plus/s/9cAHjZtdaAGnb2Q/download)
+[CondInst_MS_R_101_BiFPN_3x](MS_R_101_BiFPN_3x.yaml) | - | 45.3 | 39.6 | [model](https://cloudstor.aarnet.edu.au/plus/s/HyB0O0D7hfpUC2n/download)
+
+
+*Disclaimer:*
+- All models are trained with multi-scale data augmentation. Inference time is measured on a single NVIDIA 1080Ti with batch size 1.
+- The resolution of the final instance masks is 1/4 of the input image (i.e., `MODEL.CONDINST.MASK_OUT_STRIDE = 4`), which differs from our paper, where `MODEL.CONDINST.MASK_OUT_STRIDE = 2` was used. If you want higher-resolution mask results, please change this option.
+- This is a reimplementation, and thus the numbers are sometimes slightly different (~0.1% in mask AP) from those reported in the paper.
+
+# Citing CondInst
+If you use CondInst in your research or wish to refer to the baseline results, please use the following BibTeX entry.
+```BibTeX
+@inproceedings{tian2020conditional,
+  title = {Conditional Convolutions for Instance Segmentation},
+  author = {Tian, Zhi and Shen, Chunhua and Chen, Hao},
+  booktitle = {Proc. Eur. Conf. Computer Vision (ECCV)},
+  year = {2020}
+}
+```
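+
+# Programmatic Inference (Sketch)
+Besides `demo/demo.py`, a trained CondInst model can also be driven directly from Python via
+detectron2's `DefaultPredictor`. The snippet below is only a minimal sketch: the config,
+checkpoint and image paths are placeholders, and it assumes AdelaiDet is installed so that
+`adet.config.get_cfg` exposes the `MODEL.CONDINST.*` options added in this patch and importing
+`adet.modeling` registers the `CondInst` meta-architecture.
+```
+import cv2
+from detectron2.engine import DefaultPredictor
+
+from adet.config import get_cfg
+import adet.modeling  # noqa: F401 -- ensures AdelaiDet's models (incl. CondInst) are registered
+
+cfg = get_cfg()
+cfg.merge_from_file("configs/CondInst/MS_R_50_1x.yaml")
+cfg.MODEL.WEIGHTS = "training_dir/CondInst_MS_R_50_1x/model_final.pth"  # placeholder checkpoint
+
+predictor = DefaultPredictor(cfg)
+image = cv2.imread("input1.jpg")            # BGR image, as expected by DefaultPredictor
+instances = predictor(image)["instances"]   # Instances with pred_boxes, scores, pred_classes, pred_masks
+print(len(instances), "instances; mask tensor shape:", tuple(instances.pred_masks.shape))
+```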