File size: 3,736 Bytes
d3dbf03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmaction.models import ResNet2Plus1d
from mmaction.testing import generate_backbone_demo_inputs
def _build_frozen_r2plus1d(depth, frozen_stages):
    """Build the ``ResNet2Plus1d`` configuration shared by this test.

    Args:
        depth (int): Backbone depth passed as the first positional argument.
        frozen_stages (int): Value forwarded as ``frozen_stages``.

    Returns:
        ResNet2Plus1d: A backbone built without a pretrained checkpoint
        (``pretrained`` is ``None``) and with ``Conv2plus1d`` convolutions.
    """
    return ResNet2Plus1d(
        depth,
        None,
        conv_cfg=dict(type='Conv2plus1d'),
        pretrained2d=False,
        frozen_stages=frozen_stages,
        conv1_kernel=(3, 7, 7),
        conv1_stride_t=1,
        pool1_stride_t=1,
        inflate=(1, 1, 1, 1),
        spatial_strides=(1, 2, 2, 2),
        temporal_strides=(1, 2, 2, 2))


def _assert_stem_and_stages_frozen(model, frozen_stages):
    """Assert ``conv1`` and ``layer1..layer{frozen_stages}`` are frozen.

    "Frozen" here means what the assertions below check: norm layers are
    in eval mode (``training is False``) and no parameter requires grad.

    Args:
        model: Backbone exposing ``conv1`` and ``layer{i}`` attributes.
        frozen_stages (int): Number of ``layer{i}`` stages to inspect.
    """
    assert model.conv1.conv.bn_s.training is False
    assert model.conv1.bn.training is False
    for param in model.conv1.parameters():
        assert param.requires_grad is False

    for stage_idx in range(1, frozen_stages + 1):
        stage = getattr(model, f'layer{stage_idx}')
        for module in stage.modules():
            if isinstance(module, _BatchNorm):
                assert module.training is False
        for param in stage.parameters():
            assert param.requires_grad is False


def _assert_forward_shape(model, expected_shape):
    """Run one forward pass on demo inputs and check the output shape.

    Args:
        model: Backbone to call on the generated inputs.
        expected_shape (tuple[int]): Shape the output feature must have.
    """
    imgs = generate_backbone_demo_inputs((1, 3, 8, 64, 64))

    # parrots 3dconv is only implemented on gpu
    if torch.__version__ == 'parrots':
        if not torch.cuda.is_available():
            # No usable device under parrots: the forward check is skipped
            # silently, matching the original test's behaviour.
            return
        model = model.cuda()
        imgs = imgs.cuda()

    feat = model(imgs)
    assert feat.shape == torch.Size(expected_shape)


def test_resnet2plus1d_backbone():
    """Test the r2+1d backbone.

    Checks that invalid constructor arguments raise ``AssertionError``, then
    for depths 34 and 50 verifies that ``frozen_stages=1`` leaves ``conv1``
    and ``layer1`` frozen after ``train()`` and that a forward pass yields
    the expected feature shape.
    """
    with pytest.raises(AssertionError):
        # r2+1d does not support inflation
        ResNet2Plus1d(50, None, pretrained2d=True)

    with pytest.raises(AssertionError):
        # r2+1d requires conv(2+1)d module
        ResNet2Plus1d(
            50, None, pretrained2d=False, conv_cfg=dict(type='Conv3d'))

    frozen_stages = 1
    # (depth, expected output channels); the remaining output dims are
    # identical for both depths with the (1, 3, 8, 64, 64) demo input.
    for depth, out_channels in ((34, 512), (50, 2048)):
        model = _build_frozen_r2plus1d(depth, frozen_stages)
        model.init_weights()
        # Switch to train mode first so the assertions verify that the
        # frozen parts are still in eval mode afterwards.
        model.train()

        _assert_stem_and_stages_frozen(model, frozen_stages)
        _assert_forward_shape(model, (1, out_channels, 1, 2, 2))
|