mmaction2 / tests /models /backbones /test_resnet3d_csn.py
niobures's picture
mmaction2
d3dbf03 verified
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmaction.models import ResNet3dCSN
from mmaction.testing import generate_backbone_demo_inputs
def test_resnet_csn_backbone():
"""Test resnet_csn backbone."""
with pytest.raises(ValueError):
# Bottleneck mode must be "ip" or "ir"
ResNet3dCSN(152, None, bottleneck_mode='id')
input_shape = (2, 3, 6, 64, 64)
imgs = generate_backbone_demo_inputs(input_shape)
resnet3d_csn_frozen = ResNet3dCSN(
152, None, bn_frozen=True, norm_eval=True)
resnet3d_csn_frozen.train()
for m in resnet3d_csn_frozen.modules():
if isinstance(m, _BatchNorm):
for param in m.parameters():
assert param.requires_grad is False
# Interaction-preserved channel-separated bottleneck block
resnet3d_csn_ip = ResNet3dCSN(152, None, bottleneck_mode='ip')
resnet3d_csn_ip.init_weights()
resnet3d_csn_ip.train()
for i, layer_name in enumerate(resnet3d_csn_ip.res_layers):
layers = getattr(resnet3d_csn_ip, layer_name)
num_blocks = resnet3d_csn_ip.stage_blocks[i]
assert len(layers) == num_blocks
for layer in layers:
assert isinstance(layer.conv2, nn.Sequential)
assert len(layer.conv2) == 2
assert layer.conv2[1].groups == layer.planes
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
resnet3d_csn_ip = resnet3d_csn_ip.cuda()
imgs_gpu = imgs.cuda()
feat = resnet3d_csn_ip(imgs_gpu)
assert feat.shape == torch.Size([2, 2048, 1, 2, 2])
else:
feat = resnet3d_csn_ip(imgs)
assert feat.shape == torch.Size([2, 2048, 1, 2, 2])
# Interaction-reduced channel-separated bottleneck block
resnet3d_csn_ir = ResNet3dCSN(152, None, bottleneck_mode='ir')
resnet3d_csn_ir.init_weights()
resnet3d_csn_ir.train()
for i, layer_name in enumerate(resnet3d_csn_ir.res_layers):
layers = getattr(resnet3d_csn_ir, layer_name)
num_blocks = resnet3d_csn_ir.stage_blocks[i]
assert len(layers) == num_blocks
for layer in layers:
assert isinstance(layer.conv2, nn.Sequential)
assert len(layer.conv2) == 1
assert layer.conv2[0].groups == layer.planes
if torch.__version__ == 'parrots':
if torch.cuda.is_available():
resnet3d_csn_ir = resnet3d_csn_ir.cuda()
imgs_gpu = imgs.cuda()
feat = resnet3d_csn_ir(imgs_gpu)
assert feat.shape == torch.Size([2, 2048, 1, 2, 2])
else:
feat = resnet3d_csn_ir(imgs)
assert feat.shape == torch.Size([2, 2048, 1, 2, 2])
# Set training status = False
resnet3d_csn_ip = ResNet3dCSN(152, None, bottleneck_mode='ip')
resnet3d_csn_ip.init_weights()
resnet3d_csn_ip.train(False)
for module in resnet3d_csn_ip.children():
assert module.training is False