|
|
|
import pytest
|
|
import torch
|
|
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
|
|
|
from mmaction.models import ResNet2Plus1d
|
|
from mmaction.testing import generate_backbone_demo_inputs
|
|
|
|
|
|
def _build_r2plus1d(depth, frozen_stages):
    """Build the R(2+1)D backbone configuration exercised by this test.

    The models under test share every constructor argument except
    ``depth``, so they are built here to keep the configurations in sync.

    Args:
        depth (int): ResNet depth (e.g. 34 or 50).
        frozen_stages (int): Number of stages to freeze after the stem.

    Returns:
        ResNet2Plus1d: The constructed (not yet initialized) backbone.
    """
    return ResNet2Plus1d(
        depth,
        None,
        conv_cfg=dict(type='Conv2plus1d'),
        pretrained2d=False,
        frozen_stages=frozen_stages,
        conv1_kernel=(3, 7, 7),
        conv1_stride_t=1,
        pool1_stride_t=1,
        inflate=(1, 1, 1, 1),
        spatial_strides=(1, 2, 2, 2),
        temporal_strides=(1, 2, 2, 2))


def _assert_frozen(model, frozen_stages):
    """Assert the stem and the first ``frozen_stages`` layers are frozen.

    Expected to be called after ``model.train()``. Frozen parts must keep
    their BatchNorm modules in eval mode and have ``requires_grad``
    disabled on every parameter.

    Args:
        model (ResNet2Plus1d): Backbone built with ``frozen_stages``.
        frozen_stages (int): Number of residual stages expected frozen.
    """
    # The stem's factorized conv carries its own spatial BN (``bn_s``) in
    # addition to the stem's output BN; both must stay in eval mode.
    assert model.conv1.conv.bn_s.training is False
    assert model.conv1.bn.training is False
    for param in model.conv1.parameters():
        assert param.requires_grad is False

    # Residual stages are exposed as ``layer1``, ``layer2``, ... (1-based).
    for i in range(1, frozen_stages + 1):
        layer = getattr(model, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False


def _assert_output_shape(model, imgs, expected_shape):
    """Run a forward pass and check the shape of the output feature map.

    Under parrots the forward pass is only exercised on GPU and is skipped
    when CUDA is unavailable (presumably because parrots lacks a CPU 3D
    conv implementation). Under regular PyTorch it runs on CPU.

    Args:
        model (ResNet2Plus1d): Backbone to run.
        imgs (torch.Tensor): Input clip batch.
        expected_shape (tuple[int]): Expected output feature shape.
    """
    if torch.__version__ == 'parrots':
        if not torch.cuda.is_available():
            return
        model = model.cuda()
        imgs = imgs.cuda()
    feat = model(imgs)
    assert feat.shape == torch.Size(expected_shape)


def test_resnet2plus1d_backbone():
    """Test the R(2+1)D backbone: argument validation, freezing, forward."""
    # Invalid constructor arguments must be rejected with AssertionError.
    with pytest.raises(AssertionError):
        # ``pretrained2d=True`` is not accepted.
        ResNet2Plus1d(50, None, pretrained2d=True)

    with pytest.raises(AssertionError):
        # A plain ``Conv3d`` conv_cfg is not accepted.
        ResNet2Plus1d(
            50, None, pretrained2d=False, conv_cfg=dict(type='Conv3d'))

    frozen_stages = 1
    input_shape = (1, 3, 8, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)

    # Same configuration at two depths; only the output channel count
    # differs between them.
    for depth, out_channels in ((34, 512), (50, 2048)):
        model = _build_r2plus1d(depth, frozen_stages)
        model.init_weights()
        # Switching to train mode must not unfreeze the frozen stages.
        model.train()
        _assert_frozen(model, frozen_stages)
        _assert_output_shape(model, imgs, (1, out_channels, 1, 2, 2))
|
|
|