|
|
|
import pytest
|
|
import torch
|
|
|
|
from mmaction.models import RGBPoseConv3D
|
|
from mmaction.testing import generate_backbone_demo_inputs
|
|
|
|
|
|
def test_rgbposeconv3d():
    """Test RGBPoseConv3D backbone.

    Checks that out-of-range drop-path rates are rejected at construction
    time (via ``AssertionError``), then runs one forward pass in train mode
    and verifies the output shapes of both the RGB and the pose pathway.
    """
    # Exercise each drop-path argument on its own: if both invalid values
    # are passed in a single call, pytest.raises is satisfied by whichever
    # check fires first, so removing the other check would go unnoticed.
    with pytest.raises(AssertionError):
        RGBPoseConv3D(rgb_drop_path=1.1)
    with pytest.raises(AssertionError):
        RGBPoseConv3D(pose_drop_path=1.1)
    # Both invalid at once must still be rejected.
    with pytest.raises(AssertionError):
        RGBPoseConv3D(pose_drop_path=1.1, rgb_drop_path=1.1)

    rgbposec3d = RGBPoseConv3D()
    rgbposec3d.init_weights()
    rgbposec3d.train()

    # RGB clip; layout assumed to be (N, C, T, H, W) -- 3 channels, 8 frames.
    imgs_shape = (1, 3, 8, 224, 224)
    # Pose heatmap volume; 17 channels (presumably one per keypoint --
    # confirm against the dataset config), 32 frames at 56x56.
    heatmap_imgs_shape = (1, 17, 32, 56, 56)
    imgs = generate_backbone_demo_inputs(imgs_shape)
    heatmap_imgs = generate_backbone_demo_inputs(heatmap_imgs_shape)

    # The backbone takes both modalities and returns one feature map each.
    x_rgb, x_pose = rgbposec3d(imgs, heatmap_imgs)

    # Both pathways end on the same 7x7 spatial grid while keeping their own
    # temporal length (8 for RGB, 32 for pose) and channel width.
    assert x_rgb.shape == torch.Size([1, 2048, 8, 7, 7])
    assert x_pose.shape == torch.Size([1, 512, 32, 7, 7])
|
|
|