# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
from mmaction.models import ResNetAudio | |
from mmaction.testing import generate_backbone_demo_inputs | |
from mmaction.utils import register_all_modules | |
def test_resnet_audio_backbone():
    """Test ResNetAudio backbone in both training and inference modes.

    Builds ``ResNetAudio(50, None)``, feeds it a demo input of shape
    ``(1, 1, 16, 16)`` and checks the output feature shape after a forward
    pass in ``train()`` mode and again in ``eval()`` mode under
    ``torch.no_grad()``.
    """
    input_shape = (1, 1, 16, 16)
    spec = generate_backbone_demo_inputs(input_shape)

    register_all_modules()
    # NOTE(review): presumably (depth=50, pretrained=None); confirm against
    # the ResNetAudio signature.
    audioonly = ResNetAudio(50, None)
    audioonly.init_weights()

    # Forward pass in training mode. The original test only covered this
    # path, even though its comment said "inference".
    audioonly.train()
    feat = audioonly(spec)
    assert feat.shape == torch.Size([1, 1024, 2, 2])

    # Inference: eval mode with autograd disabled must give the same shape.
    audioonly.eval()
    with torch.no_grad():
        feat = audioonly(spec)
    assert feat.shape == torch.Size([1, 1024, 2, 2])