# File: mmaction2/tests/models/backbones/test_resnet2plus1d.py
# Repository: mmaction2 (revision d3dbf03)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmaction.models import ResNet2Plus1d
from mmaction.testing import generate_backbone_demo_inputs
def test_resnet2plus1d_backbone():
    """Exercise ``ResNet2Plus1d`` construction checks, freezing and outputs.

    The test covers three things:

    * building the backbone with ``pretrained2d=True`` or with a ``Conv3d``
      ``conv_cfg`` must raise ``AssertionError``;
    * with ``frozen_stages=1`` and the model switched to train mode, the
      checked norm layers of ``conv1`` (``conv.bn_s`` and ``bn``) and every
      ``_BatchNorm`` inside ``layer1`` stay in eval mode, and all parameters
      of ``conv1`` and ``layer1`` have ``requires_grad`` set to ``False``;
    * a forward pass on a ``(1, 3, 8, 64, 64)`` input yields a feature map of
      shape ``(1, C, 1, 2, 2)``, where ``C`` is 512 for depth 34 and 2048
      for depth 50.
    """
    # r2+1d does not support inflation
    with pytest.raises(AssertionError):
        ResNet2Plus1d(50, None, pretrained2d=True)

    # r2+1d requires conv(2+1)d module
    with pytest.raises(AssertionError):
        ResNet2Plus1d(
            50, None, pretrained2d=False, conv_cfg=dict(type='Conv3d'))

    frozen_stages = 1
    demo_shape = (1, 3, 8, 64, 64)

    # Both depths go through the same procedure; only the depth and the
    # expected number of output channels differ.
    for depth, out_channels in ((34, 512), (50, 2048)):
        # Keyword arguments are spelled out per iteration so that each model
        # receives its own freshly created config objects.
        backbone = ResNet2Plus1d(
            depth,
            None,
            conv_cfg=dict(type='Conv2plus1d'),
            pretrained2d=False,
            frozen_stages=frozen_stages,
            conv1_kernel=(3, 7, 7),
            conv1_stride_t=1,
            pool1_stride_t=1,
            inflate=(1, 1, 1, 1),
            spatial_strides=(1, 2, 2, 2),
            temporal_strides=(1, 2, 2, 2))
        backbone.init_weights()
        backbone.train()

        # The stem must remain frozen even though train() was just called.
        stem = backbone.conv1
        assert stem.conv.bn_s.training is False
        assert stem.bn.training is False
        assert all(p.requires_grad is False for p in stem.parameters())

        # Same expectation for every residual stage up to ``frozen_stages``.
        for stage_idx in range(1, frozen_stages + 1):
            stage = getattr(backbone, f'layer{stage_idx}')
            norm_layers = [
                m for m in stage.modules() if isinstance(m, _BatchNorm)
            ]
            assert all(m.training is False for m in norm_layers)
            assert all(p.requires_grad is False for p in stage.parameters())

        imgs = generate_backbone_demo_inputs(demo_shape)

        # parrots 3dconv is only implemented on gpu
        if torch.__version__ == 'parrots':
            if not torch.cuda.is_available():
                # No GPU under parrots: the forward check is not run.
                continue
            backbone = backbone.cuda()
            imgs = imgs.cuda()

        feat = backbone(imgs)
        assert feat.shape == torch.Size([1, out_channels, 1, 2, 2])