File size: 5,693 Bytes
d3dbf03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase

import pytest
import torch
import torch.nn as nn

from mmaction.models import ResNetTSM
from mmaction.models.backbones.resnet import Bottleneck
from mmaction.models.backbones.resnet_tsm import NL3DWrapper, TemporalShift
from mmaction.testing import generate_backbone_demo_inputs


class Test_ResNet_TSM(TestCase):
    """Tests for the ``ResNetTSM`` backbone.

    The tests build depth-50 models with different ``shift_place``,
    ``temporal_pool`` and non-local settings, inspect the resulting layer
    structure and check the shape of the forward output.
    """

    # Checkpoint identifier passed as ``pretrained`` by the structure tests.
    PRETRAINED = 'torchvision://resnet50'

    def setUp(self):
        """Create the default demo input shared by the forward tests."""
        input_shape = (8, 3, 64, 64)
        self.imgs = generate_backbone_demo_inputs(input_shape)

    @staticmethod
    def _non_local_kwargs():
        """Return the non-local keyword arguments used by several tests.

        A new dict is built on every call so that no test can observe
        changes another test (or the model) made to the configuration.
        """
        non_local = ((0, 0, 0), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 0, 0))
        non_local_cfg = dict(
            sub_sample=True,
            use_scale=False,
            norm_cfg=dict(type='BN3d', requires_grad=True),
            mode='embedded_gaussian')
        return dict(non_local=non_local, non_local_cfg=non_local_cfg)

    @staticmethod
    def _check_blockres_shift(block, num_segments, shift_div):
        """Assert that ``block.conv1.conv`` is a shifted 2D convolution.

        Args:
            block: Residual block whose first conv is expected to be wrapped.
            num_segments: Expected ``num_segments`` of the shift module.
            shift_div: Expected ``shift_div`` of the shift module.
        """
        shift = block.conv1.conv
        assert isinstance(shift, TemporalShift)
        assert shift.num_segments == num_segments
        assert shift.shift_div == shift_div
        assert isinstance(shift.net, nn.Conv2d)

    def test_init(self):
        """An unsupported ``shift_place`` raises ``NotImplementedError``."""
        with pytest.raises(NotImplementedError):
            # shift_place must be block or blockres
            resnet_tsm_50_block = ResNetTSM(50, shift_place='Block')
            resnet_tsm_50_block.init_weights()

    def test_init_from_scratch(self):
        """Weights can be initialised without any pretrained checkpoint."""
        resnet_tsm_50 = ResNetTSM(50, pretrained=None, pretrained2d=False)
        resnet_tsm_50.init_weights()

    def test_resnet_tsm_temporal_shift_blockres(self):
        """Default model: the shift wraps ``conv1`` of every block."""
        # resnet_tsm with depth 50
        model = ResNetTSM(50, pretrained=self.PRETRAINED)
        model.init_weights()
        for layer_name in model.res_layers:
            layer = getattr(model, layer_name)
            for block in layer.children():
                self._check_blockres_shift(block, model.num_segments,
                                           model.shift_div)
        feat = model(self.imgs)
        assert feat.shape == torch.Size([8, 2048, 2, 2])

    def test_resnet_tsm_temporal_shift_block(self):
        """``shift_place='block'``: the shift wraps each whole block."""
        # resnet_tsm with depth 50, pretrained weights, shift_place is block
        model = ResNetTSM(
            50, shift_place='block', pretrained=self.PRETRAINED)
        model.init_weights()
        for layer_name in model.res_layers:
            layer = getattr(model, layer_name)
            for block in layer.children():
                assert isinstance(block, TemporalShift)
                assert block.num_segments == model.num_segments
                assert block.shift_div == model.shift_div
                assert isinstance(block.net, Bottleneck)

    def test_resnet_tsm_temporal_pool(self):
        """``temporal_pool=True``: segments are halved after ``layer1``."""
        # resnet_tsm with depth 50, pretrained weights, use temporal_pool
        model = ResNetTSM(
            50, temporal_pool=True, pretrained=self.PRETRAINED)
        model.init_weights()
        for layer_name in model.res_layers:
            layer = getattr(model, layer_name)
            blocks = list(layer.children())

            if layer_name == 'layer2':
                # layer2 has exactly two children here: the original stack
                # of residual blocks followed by a 3D max pooling module.
                assert len(blocks) == 2
                assert isinstance(blocks[1], nn.MaxPool3d)
                blocks = copy.deepcopy(blocks[0])

            if layer_name == 'layer1':
                expected_segments = model.num_segments
            else:
                expected_segments = model.num_segments // 2

            for block in blocks:
                self._check_blockres_shift(block, expected_segments,
                                           model.shift_div)

        feat = model(self.imgs)
        # The first output dimension is half of the 8 input frames.
        assert feat.shape == torch.Size([4, 2048, 2, 2])

    def test_resnet_tsm_non_local(self):
        """Non-local wrappers sit at the even block indices."""
        # resnet_tsm with non-local module
        model = ResNetTSM(
            50, pretrained=self.PRETRAINED, **self._non_local_kwargs())
        model.init_weights()
        for layer_name in ['layer2', 'layer3']:
            layer = getattr(model, layer_name)
            for idx, block in enumerate(layer):
                # The non_local layout for these layers alternates 1, 0, ...
                if idx % 2 == 0:
                    assert isinstance(block, NL3DWrapper)

        feat = model(self.imgs)
        assert feat.shape == torch.Size([8, 2048, 2, 2])

    def test_resnet_tsm_full(self):
        """Non-local blocks and temporal pooling work together."""
        model = ResNetTSM(
            50,
            pretrained=self.PRETRAINED,
            temporal_pool=True,
            **self._non_local_kwargs())
        model.init_weights()

        input_shape = (16, 3, 32, 32)
        imgs = generate_backbone_demo_inputs(input_shape)
        feat = model(imgs)
        assert feat.shape == torch.Size([8, 2048, 1, 1])