# Source: mmaction2/tests/models/backbones/test_resnet_tsm.py
# (uploaded by niobures; revision d3dbf03, verified)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
import pytest
import torch
import torch.nn as nn
from mmaction.models import ResNetTSM
from mmaction.models.backbones.resnet import Bottleneck
from mmaction.models.backbones.resnet_tsm import NL3DWrapper, TemporalShift
from mmaction.testing import generate_backbone_demo_inputs
class Test_ResNet_TSM(TestCase):
    """Tests for the ``ResNetTSM`` backbone.

    Covers rejection of an unknown ``shift_place``, initialisation without
    a checkpoint, the default (``'blockres'``) and ``'block'`` temporal-shift
    placements, temporal pooling, non-local wrapping, and all of them
    combined.
    """

    def setUp(self):
        # Demo input of shape (8, 3, 64, 64) shared by most tests below.
        input_shape = (8, 3, 64, 64)
        self.imgs = generate_backbone_demo_inputs(input_shape)

    def test_init(self):
        """An unsupported ``shift_place`` raises ``NotImplementedError``."""
        with pytest.raises(NotImplementedError):
            # shift_place must be 'block' or 'blockres'; the value is
            # matched case-sensitively, so 'Block' is rejected.
            resnet_tsm_50_block = ResNetTSM(50, shift_place='Block')
            resnet_tsm_50_block.init_weights()

    def test_init_from_scratch(self):
        """The backbone initialises without any pretrained weights."""
        resnet_tsm_50 = ResNetTSM(50, pretrained=None, pretrained2d=False)
        resnet_tsm_50.init_weights()

    def test_resnet_tsm_temporal_shift_blockres(self):
        """Default placement wraps every block's ``conv1.conv`` in a shift."""
        # resnet_tsm with depth 50, initialised from torchvision weights
        resnet_tsm_50 = ResNetTSM(50, pretrained='torchvision://resnet50')
        resnet_tsm_50.init_weights()
        for layer_name in resnet_tsm_50.res_layers:
            layer = getattr(resnet_tsm_50, layer_name)
            for block in layer.children():
                shift = block.conv1.conv
                assert isinstance(shift, TemporalShift)
                assert shift.num_segments == resnet_tsm_50.num_segments
                assert shift.shift_div == resnet_tsm_50.shift_div
                # the shift module wraps the original 2D convolution
                assert isinstance(shift.net, nn.Conv2d)
        feat = resnet_tsm_50(self.imgs)
        assert feat.shape == torch.Size([8, 2048, 2, 2])

    def test_resnet_tsm_temporal_shift_block(self):
        """``'block'`` placement wraps each whole Bottleneck in a shift."""
        # resnet_tsm with depth 50, torchvision weights, shift_place='block'
        resnet_tsm_50_block = ResNetTSM(
            50, shift_place='block', pretrained='torchvision://resnet50')
        resnet_tsm_50_block.init_weights()
        for layer_name in resnet_tsm_50_block.res_layers:
            layer = getattr(resnet_tsm_50_block, layer_name)
            for block in layer.children():
                assert isinstance(block, TemporalShift)
                assert block.num_segments == resnet_tsm_50_block.num_segments
                assert block.shift_div == resnet_tsm_50_block.shift_div
                assert isinstance(block.net, Bottleneck)

    def test_resnet_tsm_temporal_pool(self):
        """``temporal_pool`` wraps layer2 with pooling, halving segments."""
        # resnet_tsm with depth 50, torchvision weights, temporal_pool on
        resnet_tsm_50_temporal_pool = ResNetTSM(
            50, temporal_pool=True, pretrained='torchvision://resnet50')
        resnet_tsm_50_temporal_pool.init_weights()
        num_segments = resnet_tsm_50_temporal_pool.num_segments
        shift_div = resnet_tsm_50_temporal_pool.shift_div
        for layer_name in resnet_tsm_50_temporal_pool.res_layers:
            layer = getattr(resnet_tsm_50_temporal_pool, layer_name)
            blocks = list(layer.children())
            if layer_name == 'layer2':
                # layer2 is replaced by a wrapper with two children: the
                # container of layer2's blocks and a MaxPool3d. Unwrap it
                # so the inner blocks are checked like every other stage.
                assert len(blocks) == 2
                assert isinstance(blocks[1], nn.MaxPool3d)
                blocks = copy.deepcopy(blocks[0])
            # Only layer1 keeps the full segment count; the later stages
            # are configured with half of it.
            expected_segments = (
                num_segments if layer_name == 'layer1' else num_segments // 2)
            for block in blocks:
                shift = block.conv1.conv
                assert isinstance(shift, TemporalShift)
                assert shift.num_segments == expected_segments
                assert shift.shift_div == shift_div
                assert isinstance(shift.net, nn.Conv2d)
        feat = resnet_tsm_50_temporal_pool(self.imgs)
        # temporal pooling halves the leading dimension of the output (8 -> 4)
        assert feat.shape == torch.Size([4, 2048, 2, 2])

    def test_resnet_tsm_non_local(self):
        """Only blocks flagged in ``non_local`` are wrapped in NL3DWrapper."""
        non_local_cfg = dict(
            sub_sample=True,
            use_scale=False,
            norm_cfg=dict(type='BN3d', requires_grad=True),
            mode='embedded_gaussian')
        # one flag per block of each ResNet-50 stage (3, 4, 6, 3 blocks)
        non_local = ((0, 0, 0), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 0, 0))
        resnet_tsm_nonlocal = ResNetTSM(
            50,
            non_local=non_local,
            non_local_cfg=non_local_cfg,
            pretrained='torchvision://resnet50')
        resnet_tsm_nonlocal.init_weights()
        res_layers = resnet_tsm_nonlocal.res_layers
        # guard against zip() silently truncating a mismatched config
        assert len(res_layers) == len(non_local)
        for layer_name, flags in zip(res_layers, non_local):
            layer = getattr(resnet_tsm_nonlocal, layer_name)
            assert len(layer) == len(flags)
            # flagged blocks must be wrapped; unflagged ones must not be
            for block, flag in zip(layer, flags):
                assert isinstance(block, NL3DWrapper) == bool(flag)
        feat = resnet_tsm_nonlocal(self.imgs)
        assert feat.shape == torch.Size([8, 2048, 2, 2])

    def test_resnet_tsm_full(self):
        """Temporal shift, non-local blocks and temporal pool together."""
        non_local_cfg = dict(
            sub_sample=True,
            use_scale=False,
            norm_cfg=dict(type='BN3d', requires_grad=True),
            mode='embedded_gaussian')
        non_local = ((0, 0, 0), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 0, 0))
        resnet_tsm_50_full = ResNetTSM(
            50,
            pretrained='torchvision://resnet50',
            non_local=non_local,
            non_local_cfg=non_local_cfg,
            temporal_pool=True)
        resnet_tsm_50_full.init_weights()
        input_shape = (16, 3, 32, 32)
        imgs = generate_backbone_demo_inputs(input_shape)
        feat = resnet_tsm_50_full(imgs)
        # temporal pooling halves the leading dimension of the output (16 -> 8)
        assert feat.shape == torch.Size([8, 2048, 1, 1])