diff --git a/detection/mmdet_custom/models/backbones/intern_image.py b/detection/mmdet_custom/models/backbones/intern_image.py
index dc87c8d9..a3737bcb 100644
--- a/detection/mmdet_custom/models/backbones/intern_image.py
+++ b/detection/mmdet_custom/models/backbones/intern_image.py
@@ -15,7 +15,7 @@ from mmdet.models.builder import BACKBONES
 import torch.nn.functional as F
 
 
-from ops_dcnv3 import modules as dcnv3
+from ops_dcnv3 import modules as opsm
 
 
 class to_channels_first(nn.Module):
@@ -365,8 +365,7 @@ def __init__(self,
                  with_cp=False,
                  dw_kernel_size=None, # for InternImage-H/G
                  res_post_norm=False, # for InternImage-H/G
-                 center_feature_scale=False,
-                 use_dcn_v4_op=False): # for InternImage-H/G
+                 center_feature_scale=False): # for InternImage-H/G
         super().__init__()
         self.channels = channels
         self.groups = groups
@@ -386,8 +385,7 @@ def __init__(self,
             act_layer=act_layer,
             norm_layer=norm_layer,
             dw_kernel_size=dw_kernel_size, # for InternImage-H/G
-            center_feature_scale=center_feature_scale,
-            use_dcn_v4_op=use_dcn_v4_op) # for InternImage-H/G
+            center_feature_scale=center_feature_scale) # for InternImage-H/G
         self.drop_path = DropPath(drop_path) if drop_path > 0. \
             else nn.Identity()
         self.norm2 = build_norm_layer(channels, 'LN')
@@ -471,8 +469,7 @@ def __init__(self,
                  dw_kernel_size=None, # for InternImage-H/G
                  post_norm_block_ids=None, # for InternImage-H/G
                  res_post_norm=False, # for InternImage-H/G
-                 center_feature_scale=False, # for InternImage-H/G
-                 use_dcn_v4_op=False):
+                 center_feature_scale=False): # for InternImage-H/G
         super().__init__()
         self.channels = channels
         self.depth = depth
@@ -496,8 +493,7 @@ def __init__(self,
                 with_cp=with_cp,
                 dw_kernel_size=dw_kernel_size, # for InternImage-H/G
                 res_post_norm=res_post_norm, # for InternImage-H/G
-                center_feature_scale=center_feature_scale, # for InternImage-H/G
-                use_dcn_v4_op=use_dcn_v4_op
+                center_feature_scale=center_feature_scale # for InternImage-H/G
             ) for i in range(depth)
         ])
         if not self.post_norm or center_feature_scale:
@@ -573,11 +569,12 @@ def __init__(self,
                  level2_post_norm_block_ids=None, # for InternImage-H/G
                  res_post_norm=False, # for InternImage-H/G
                  center_feature_scale=False, # for InternImage-H/G
-                 use_dcn_v4_op=False,
                  out_indices=(0, 1, 2, 3),
                  init_cfg=None,
+                 frozen_stages=-1, # number of leading levels to freeze (1..len(depths)); -1 freezes none
                  **kwargs):
         super().__init__()
+        self.frozen_stages = frozen_stages
         self.core_op = core_op
         self.num_levels = len(depths)
         self.depths = depths
@@ -596,7 +593,6 @@ def __init__(self,
         logger.info(f"level2_post_norm: {level2_post_norm}")
         logger.info(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}")
         logger.info(f"res_post_norm: {res_post_norm}")
-        logger.info(f"use_dcn_v4_op: {use_dcn_v4_op}")
 
         in_chans = 3
         self.patch_embed = StemLayer(in_chans=in_chans,
@@ -617,7 +613,7 @@ def __init__(self,
             post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
                 i == 2) else None # for InternImage-H/G
             level = InternImageBlock(
-                core_op=getattr(dcnv3, core_op),
+                core_op=getattr(opsm, core_op),
                 channels=int(channels * 2**i),
                 depth=depths[i],
                 groups=groups[i],
@@ -634,14 +630,35 @@ def __init__(self,
                 dw_kernel_size=dw_kernel_size, # for InternImage-H/G
                 post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
                 res_post_norm=res_post_norm, # for InternImage-H/G
-                center_feature_scale=center_feature_scale, # for InternImage-H/G
-                use_dcn_v4_op=use_dcn_v4_op,
+                center_feature_scale=center_feature_scale # for InternImage-H/G
             )
             self.levels.append(level)
 
         self.num_layers = len(depths)
         self.apply(self._init_weights)
         self.apply(self._init_deform_weights)
+        self._freeze_stages()
+
+    def train(self, mode=True):
+        """Set the training mode while keeping the frozen levels in eval mode.
+
+        Returns:
+            InternImage: ``self``, matching ``nn.Module.train``.
+        """
+        super(InternImage, self).train(mode)
+        self._freeze_stages()
+        return self
+
+    def _freeze_stages(self):
+        """Put the first ``frozen_stages`` levels in eval mode and stop their
+        gradients; a negative ``frozen_stages`` freezes nothing."""
+        # The guard matters: ``self.levels[:-1]`` would otherwise freeze all
+        # but the last level when ``frozen_stages`` is -1.
+        if self.frozen_stages >= 0:
+            for level in self.levels[:self.frozen_stages]:
+                level.eval()
+                for param in level.parameters():
+                    param.requires_grad = False
 
     def init_weights(self):
         logger = get_root_logger()
@@ -694,7 +711,7 @@ def _init_weights(self, m):
             nn.init.constant_(m.weight, 1.0)
 
     def _init_deform_weights(self, m):
-        if isinstance(m, getattr(dcnv3, self.core_op)):
+        if isinstance(m, getattr(opsm, self.core_op)):
             m._reset_parameters()
 
     def forward(self, x):