File size: 3,736 Bytes
d3dbf03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmaction.models import ResNet2Plus1d
from mmaction.testing import generate_backbone_demo_inputs


def _build_frozen_r2plus1d(depth, frozen_stages):
    """Build a ``ResNet2Plus1d`` backbone, initialise it and enter train mode.

    Args:
        depth (int): Depth passed as the first positional argument of
            ``ResNet2Plus1d``.
        frozen_stages (int): Value forwarded as ``frozen_stages``.

    Returns:
        ResNet2Plus1d: The model after ``init_weights()`` and ``train()``.
    """
    model = ResNet2Plus1d(
        depth,
        None,
        conv_cfg=dict(type='Conv2plus1d'),
        pretrained2d=False,
        frozen_stages=frozen_stages,
        conv1_kernel=(3, 7, 7),
        conv1_stride_t=1,
        pool1_stride_t=1,
        inflate=(1, 1, 1, 1),
        spatial_strides=(1, 2, 2, 2),
        temporal_strides=(1, 2, 2, 2))
    model.init_weights()
    # Called on purpose: the frozen parts are expected to stay in eval mode
    # even after the whole model is switched to training.
    model.train()
    return model


def _assert_stem_and_stages_frozen(model, frozen_stages):
    """Assert that ``conv1`` and ``layer1..layer{frozen_stages}`` are frozen.

    "Frozen" here means every checked norm layer reports
    ``training is False`` and every parameter has ``requires_grad is False``.

    Args:
        model: Backbone exposing ``conv1`` and ``layer{i}`` attributes.
        frozen_stages (int): Number of leading residual stages to check.
    """
    assert model.conv1.conv.bn_s.training is False
    assert model.conv1.bn.training is False
    for param in model.conv1.parameters():
        assert param.requires_grad is False

    for stage_idx in range(1, frozen_stages + 1):
        stage = getattr(model, f'layer{stage_idx}')
        for mod in stage.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in stage.parameters():
            assert param.requires_grad is False


def _assert_forward_shape(model, expected_shape):
    """Run one forward pass on demo inputs and check the output shape.

    Under parrots the pass runs on GPU, and is skipped silently when CUDA is
    unavailable; otherwise it runs on the device the model already lives on.

    Args:
        model: Backbone to call.
        expected_shape (tuple[int]): Expected shape of the returned feature.
    """
    imgs = generate_backbone_demo_inputs((1, 3, 8, 64, 64))
    # parrots 3dconv is only implemented on gpu
    if torch.__version__ == 'parrots':
        if not torch.cuda.is_available():
            return
        model = model.cuda()
        imgs = imgs.cuda()
    feat = model(imgs)
    assert feat.shape == torch.Size(expected_shape)


def test_resnet2plus1d_backbone():
    """Test the R(2+1)D backbone: config validation, freezing and forward.

    Checks that invalid configurations raise ``AssertionError``, then, for
    depths 34 and 50, that the frozen stem/stages stay frozen after
    ``train()`` and that a ``(1, 3, 8, 64, 64)`` input produces the expected
    feature shape.
    """
    with pytest.raises(AssertionError):
        # r2+1d does not support inflation
        ResNet2Plus1d(50, None, pretrained2d=True)

    with pytest.raises(AssertionError):
        # r2+1d requires conv(2+1)d module
        ResNet2Plus1d(
            50, None, pretrained2d=False, conv_cfg=dict(type='Conv3d'))

    frozen_stages = 1
    # (depth, expected number of output channels)
    for depth, out_channels in ((34, 512), (50, 2048)):
        model = _build_frozen_r2plus1d(depth, frozen_stages)
        _assert_stem_and_stages_frozen(model, frozen_stages)
        _assert_forward_shape(model, (1, out_channels, 1, 2, 2))