File size: 600 Bytes
d3dbf03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmaction.models import ResNetAudio
from mmaction.testing import generate_backbone_demo_inputs
from mmaction.utils import register_all_modules


def test_resnet_audio_backbone():
    """Check the output shape of a ResNetAudio forward pass.

    A ``ResNetAudio(50, None)`` backbone in train mode is fed a demo input
    of shape ``(1, 1, 16, 16)`` and must produce a feature map of shape
    ``(1, 1024, 2, 2)``.
    """
    # Demo input produced by the shared testing helper.
    spectrogram = generate_backbone_demo_inputs((1, 1, 16, 16))

    # Called before the backbone is constructed, as in the original test.
    register_all_modules()

    # Positional args kept as-is: presumably (depth, pretrained) -- confirm
    # against the ResNetAudio signature.
    backbone = ResNetAudio(50, None)
    backbone.init_weights()
    # The forward pass below runs with the module in train mode.
    backbone.train()

    output = backbone(spectrogram)
    expected_shape = torch.Size([1, 1024, 2, 2])
    assert output.shape == expected_shape