File size: 5,693 Bytes
d3dbf03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
import pytest
import torch
import torch.nn as nn
from mmaction.models import ResNetTSM
from mmaction.models.backbones.resnet import Bottleneck
from mmaction.models.backbones.resnet_tsm import NL3DWrapper, TemporalShift
from mmaction.testing import generate_backbone_demo_inputs
class Test_ResNet_TSM(TestCase):
    """Unit tests for the ``ResNetTSM`` backbone.

    Most tests build a depth-50 model with
    ``pretrained='torchvision://resnet50'``, call ``init_weights()`` and then
    check the resulting module structure and/or the output feature shape.
    NOTE(review): loading torchvision weights presumably needs network access
    or a local checkpoint cache — confirm for the CI environment.
    """

    @staticmethod
    def _non_local_settings():
        """Return fresh ``(non_local, non_local_cfg)`` arguments for ResNetTSM.

        ``non_local`` has one tuple per residual stage; its lengths match the
        3/4/6/3 blocks per stage of ResNet-50, and a ``1`` marks a block that
        should be wrapped by ``NL3DWrapper``. New objects are created on every
        call so a test can never observe another test's mutations.
        """
        non_local_cfg = dict(
            sub_sample=True,
            use_scale=False,
            norm_cfg=dict(type='BN3d', requires_grad=True),
            mode='embedded_gaussian')
        non_local = ((0, 0, 0), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 0, 0))
        return non_local, non_local_cfg

    def setUp(self):
        # Shared input batch of shape (8, 3, 64, 64) used by most tests.
        input_shape = (8, 3, 64, 64)
        self.imgs = generate_backbone_demo_inputs(input_shape)

    def test_init(self):
        """An unsupported ``shift_place`` raises ``NotImplementedError``."""
        # shift_place must be 'block' or 'blockres'; the check is
        # case-sensitive, so 'Block' is rejected.
        with pytest.raises(NotImplementedError):
            resnet_tsm_50_block = ResNetTSM(50, shift_place='Block')
            resnet_tsm_50_block.init_weights()

    def test_init_from_scratch(self):
        """Building and initialising without any pretrained weights works."""
        resnet_tsm_50 = ResNetTSM(50, pretrained=None, pretrained2d=False)
        resnet_tsm_50.init_weights()

    def test_resnet_tsm_temporal_shift_blockres(self):
        """Default shift placement wraps each block's ``conv1.conv``."""
        # resnet_tsm with depth 50, pretrained, default shift_place
        resnet_tsm_50 = ResNetTSM(50, pretrained='torchvision://resnet50')
        resnet_tsm_50.init_weights()
        for layer_name in resnet_tsm_50.res_layers:
            layer = getattr(resnet_tsm_50, layer_name)
            for block in layer.children():
                shift = block.conv1.conv
                assert isinstance(shift, TemporalShift)
                assert shift.num_segments == resnet_tsm_50.num_segments
                assert shift.shift_div == resnet_tsm_50.shift_div
                assert isinstance(shift.net, nn.Conv2d)
        feat = resnet_tsm_50(self.imgs)
        assert feat.shape == torch.Size([8, 2048, 2, 2])

    def test_resnet_tsm_temporal_shift_block(self):
        """``shift_place='block'`` wraps every whole Bottleneck block."""
        # resnet_tsm with depth 50, pretrained, shift_place is block
        resnet_tsm_50_block = ResNetTSM(
            50, shift_place='block', pretrained='torchvision://resnet50')
        resnet_tsm_50_block.init_weights()
        for layer_name in resnet_tsm_50_block.res_layers:
            layer = getattr(resnet_tsm_50_block, layer_name)
            for block in layer.children():
                assert isinstance(block, TemporalShift)
                assert block.num_segments == resnet_tsm_50_block.num_segments
                assert block.shift_div == resnet_tsm_50_block.shift_div
                assert isinstance(block.net, Bottleneck)

    def test_resnet_tsm_temporal_pool(self):
        """``temporal_pool=True`` inserts a MaxPool3d after layer2."""
        # resnet_tsm with depth 50, pretrained, use temporal_pool
        resnet_tsm_50_temporal_pool = ResNetTSM(
            50, temporal_pool=True, pretrained='torchvision://resnet50')
        resnet_tsm_50_temporal_pool.init_weights()
        num_segments = resnet_tsm_50_temporal_pool.num_segments
        shift_div = resnet_tsm_50_temporal_pool.shift_div
        for layer_name in resnet_tsm_50_temporal_pool.res_layers:
            layer = getattr(resnet_tsm_50_temporal_pool, layer_name)
            blocks = list(layer.children())
            if layer_name == 'layer2':
                # layer2 becomes a (residual blocks, MaxPool3d) pair; check
                # the pool, then inspect the wrapped residual blocks.
                assert len(blocks) == 2
                assert isinstance(blocks[1], nn.MaxPool3d)
                blocks = copy.deepcopy(blocks[0])
            # Per the expectations below, only layer1 sees the full number
            # of segments; the later stages see half of them.
            if layer_name == 'layer1':
                expected_segments = num_segments
            else:
                expected_segments = num_segments // 2
            for block in blocks:
                shift = block.conv1.conv
                assert isinstance(shift, TemporalShift)
                assert shift.num_segments == expected_segments
                assert shift.shift_div == shift_div
                assert isinstance(shift.net, nn.Conv2d)
        feat = resnet_tsm_50_temporal_pool(self.imgs)
        # The temporal pool halves the leading (batch * time) dimension.
        assert feat.shape == torch.Size([4, 2048, 2, 2])

    def test_resnet_tsm_non_local(self):
        """Exactly the blocks flagged in ``non_local`` get ``NL3DWrapper``."""
        non_local, non_local_cfg = self._non_local_settings()
        resnet_tsm_nonlocal = ResNetTSM(
            50,
            non_local=non_local,
            non_local_cfg=non_local_cfg,
            pretrained='torchvision://resnet50')
        resnet_tsm_nonlocal.init_weights()
        # Compare every block of every stage against its flag, so that both
        # missing and spurious non-local wrappers are detected.
        for layer_name, flags in zip(resnet_tsm_nonlocal.res_layers,
                                     non_local):
            layer = getattr(resnet_tsm_nonlocal, layer_name)
            assert len(layer) == len(flags)
            for block, flag in zip(layer, flags):
                assert isinstance(block, NL3DWrapper) == bool(flag)
        feat = resnet_tsm_nonlocal(self.imgs)
        assert feat.shape == torch.Size([8, 2048, 2, 2])

    def test_resnet_tsm_full(self):
        """Non-local blocks combined with temporal pooling run end to end."""
        non_local, non_local_cfg = self._non_local_settings()
        resnet_tsm_50_full = ResNetTSM(
            50,
            pretrained='torchvision://resnet50',
            non_local=non_local,
            non_local_cfg=non_local_cfg,
            temporal_pool=True)
        resnet_tsm_50_full.init_weights()
        input_shape = (16, 3, 32, 32)
        imgs = generate_backbone_demo_inputs(input_shape)
        feat = resnet_tsm_50_full(imgs)
        # The temporal pool halves the leading dimension: 16 -> 8.
        assert feat.shape == torch.Size([8, 2048, 1, 1])
|