# File: mmaction2/tests/models/backbones/test_uniformer.py
# Hosting-page header (not part of the module): niobures / mmaction2, commit d3dbf03 (verified)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmaction.models import UniFormer
from mmaction.testing import generate_backbone_demo_inputs
def test_uniformer_backbone():
    """Check the output shape of the UniFormer backbone.

    Builds a demo input of shape ``(1, 3, 16, 64, 64)``, runs it through a
    ``UniFormer`` instance in eval mode and asserts that the result has shape
    ``(1, 512, 8, 2, 2)``.
    """
    # NOTE(review): the 5-D shape presumably means (batch, channels, frames,
    # height, width) -- confirm against generate_backbone_demo_inputs.
    demo_shape = (1, 3, 16, 64, 64)
    demo_clip = generate_backbone_demo_inputs(demo_shape)

    backbone = UniFormer(
        depth=[3, 4, 8, 3],
        embed_dim=[64, 128, 320, 512],
        head_dim=64,
        drop_path_rate=0.1,
    )
    backbone.init_weights()
    backbone.eval()

    features = backbone(demo_clip)
    expected_shape = torch.Size([1, 512, 8, 2, 2])
    assert features.shape == expected_shape