|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
|
|
|
from mmaction.models import ResNet3dCSN
|
|
from mmaction.testing import generate_backbone_demo_inputs
|
|
|
|
|
|
def _assert_csn_conv2(model, conv2_len, grouped_idx):
    """Check the ``conv2`` structure of every block in every residual stage.

    For each stage named in ``model.res_layers``, asserts that it holds
    ``model.stage_blocks[i]`` blocks. It also asserts that each block's
    ``conv2`` is an ``nn.Sequential`` of length ``conv2_len`` whose
    ``grouped_idx``-th entry has ``groups == block.planes``.

    Args:
        model: A ``ResNet3dCSN`` instance.
        conv2_len (int): Expected number of modules inside each ``conv2``.
        grouped_idx (int): Index of the grouped convolution inside ``conv2``.
    """
    # enumerate + index (rather than zip) so a stage_blocks/res_layers length
    # mismatch raises instead of silently truncating.
    for i, layer_name in enumerate(model.res_layers):
        layers = getattr(model, layer_name)
        assert len(layers) == model.stage_blocks[i]
        for block in layers:
            assert isinstance(block.conv2, nn.Sequential)
            assert len(block.conv2) == conv2_len
            assert block.conv2[grouped_idx].groups == block.planes


def _assert_forward_shape(model, imgs, expected_shape):
    """Run ``model`` on ``imgs`` and check the output feature shape.

    Under parrots (``torch.__version__ == 'parrots'``), the forward pass runs
    only when CUDA is available; both model and inputs are moved to the GPU
    first. Otherwise the forward pass runs on the CPU.

    Args:
        model: Backbone to run. Under parrots+CUDA it is moved to the GPU
            in place by ``.cuda()``.
        imgs (torch.Tensor): Input batch.
        expected_shape (torch.Size): Expected output shape.
    """
    if torch.__version__ == 'parrots':
        # NOTE(review): parrots skips the CPU forward entirely; presumably
        # the parrots ops used here are CUDA-only -- confirm.
        if not torch.cuda.is_available():
            return
        model = model.cuda()
        imgs = imgs.cuda()
    feat = model(imgs)
    assert feat.shape == expected_shape


def test_resnet_csn_backbone():
    """Test resnet_csn backbone."""
    # 'id' is not an accepted bottleneck_mode ('ip' and 'ir' are used below).
    with pytest.raises(ValueError):
        ResNet3dCSN(152, None, bottleneck_mode='id')

    input_shape = (2, 3, 6, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)
    # Output feature shape expected for ``input_shape`` above.
    expected_shape = torch.Size([2, 2048, 1, 2, 2])

    # With bn_frozen=True, BN parameters must not require grad, even after
    # switching the model to train mode.
    resnet3d_csn_frozen = ResNet3dCSN(
        152, None, bn_frozen=True, norm_eval=True)
    resnet3d_csn_frozen.train()
    for m in resnet3d_csn_frozen.modules():
        if isinstance(m, _BatchNorm):
            for param in m.parameters():
                assert param.requires_grad is False

    # 'ip' mode: conv2 is a 2-module Sequential; the second one is grouped.
    resnet3d_csn_ip = ResNet3dCSN(152, None, bottleneck_mode='ip')
    resnet3d_csn_ip.init_weights()
    resnet3d_csn_ip.train()
    _assert_csn_conv2(resnet3d_csn_ip, conv2_len=2, grouped_idx=1)
    _assert_forward_shape(resnet3d_csn_ip, imgs, expected_shape)

    # 'ir' mode: conv2 is a single grouped conv wrapped in a Sequential.
    resnet3d_csn_ir = ResNet3dCSN(152, None, bottleneck_mode='ir')
    resnet3d_csn_ir.init_weights()
    resnet3d_csn_ir.train()
    _assert_csn_conv2(resnet3d_csn_ir, conv2_len=1, grouped_idx=0)
    _assert_forward_shape(resnet3d_csn_ir, imgs, expected_shape)

    # train(False) recurses through the whole module tree, so every module
    # (not just the direct children) must end up in eval mode.
    resnet3d_csn_eval = ResNet3dCSN(152, None, bottleneck_mode='ip')
    resnet3d_csn_eval.init_weights()
    resnet3d_csn_eval.train(False)
    for module in resnet3d_csn_eval.modules():
        assert module.training is False
|
|
|