niobures's picture
mmaction2
d3dbf03 verified
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmaction.models import AAGCN
from mmaction.utils import register_all_modules
def test_aagcn_backbone():
"""Test AAGCN backbone."""
register_all_modules()
mode = 'spatial'
batch_size, num_person, num_frames = 2, 2, 150
# openpose-18 layout
num_joints = 18
model = AAGCN(graph_cfg=dict(layout='openpose', mode=mode))
model.init_weights()
inputs = torch.randn(batch_size, num_person, num_frames, num_joints, 3)
output = model(inputs)
assert output.shape == torch.Size([2, 2, 256, 38, 18])
# nturgb+d layout
num_joints = 25
model = AAGCN(graph_cfg=dict(layout='nturgb+d', mode=mode))
model.init_weights()
inputs = torch.randn(batch_size, num_person, num_frames, num_joints, 3)
output = model(inputs)
assert output.shape == torch.Size([2, 2, 256, 38, 25])
# coco layout
num_joints = 17
model = AAGCN(graph_cfg=dict(layout='coco', mode=mode))
model.init_weights()
inputs = torch.randn(batch_size, num_person, num_frames, num_joints, 3)
output = model(inputs)
assert output.shape == torch.Size([2, 2, 256, 38, 17])
# custom settings
# disable the attention module to degenerate AAGCN to AGCN
model = AAGCN(
graph_cfg=dict(layout='coco', mode=mode), gcn_attention=False)
model.init_weights()
output = model(inputs)
assert output.shape == torch.Size([2, 2, 256, 38, 17])