File size: 2,376 Bytes
d3dbf03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock

import torch

from mmaction.registry import MODELS
from mmaction.structures import ActionDataSample
from mmaction.testing import get_skeletongcn_cfg
from mmaction.utils import register_all_modules


def train_test_step(cfg, input_shape):
    """Build a recognizer from ``cfg`` and smoke-test its train/test steps.

    Args:
        cfg: Config whose ``model`` entry is passed to ``MODELS.build``;
            ``cfg.model.cls_head.num_classes`` gives the expected score
            length.
        input_shape (tuple[int]): Shape of the single random input tensor.
            The first axis is treated as the clip axis (it is replaced by
            ``num_clips`` in the multi-clip check). Callers in this file
            pass ``(N, M, T, V, C)``.

    Returns:
        tuple: ``(loss_vars, predictions)``. ``loss_vars`` is the output of
        ``train_step``. ``predictions`` is the output of the last
        ``test_step`` call, made with ``average_clips=None``.
    """
    recognizer = MODELS.build(cfg.model)
    num_classes = cfg.model.cls_head.num_classes
    data_batch = {
        'inputs': [torch.randn(input_shape)],
        'data_samples': [ActionDataSample().set_gt_label(2)]
    }

    # test train_step: the optimizer wrapper is mocked, so only check that
    # the expected loss keys are produced and one update is requested.
    optim_wrapper = MagicMock()
    loss_vars = recognizer.train_step(data_batch, optim_wrapper)
    assert 'loss' in loss_vars
    assert 'loss_cls' in loss_vars
    optim_wrapper.update_params.assert_called_once()

    # test test_step: with the head's configured clip averaging, expect one
    # score vector of length num_classes with values in [0, 1].
    with torch.no_grad():
        predictions = recognizer.test_step(data_batch)
    # Check the count before indexing so a wrong number of predictions
    # fails as an assertion instead of an IndexError.
    assert len(predictions) == 1
    score = predictions[0].pred_score
    assert score.shape == torch.Size([num_classes])
    assert torch.min(score) >= 0
    assert torch.max(score) <= 1

    # test when average_clips is None: the score keeps a leading
    # num_clips axis instead of being reduced to a single vector.
    recognizer.cls_head.average_clips = None
    num_clips = 3
    input_shape = (num_clips, *input_shape[1:])
    data_batch['inputs'] = [torch.randn(input_shape)]
    with torch.no_grad():
        predictions = recognizer.test_step(data_batch)
    assert len(predictions) == 1
    score = predictions[0].pred_score
    assert score.shape == torch.Size([num_clips, num_classes])

    return loss_vars, predictions


def test_stgcn():
    """Run the shared train/test smoke check on the ST-GCN joint config."""
    register_all_modules()
    path = 'stgcn/stgcn_8xb16-joint-u100-80e_ntu60-xsub-keypoint-2d.py'
    cfg = get_skeletongcn_cfg(path)
    # single-sample layout: N M T V C
    train_test_step(cfg, input_shape=(1, 2, 30, 17, 3))


def test_agcn():
    """Run the shared train/test smoke check on the 2s-AGCN joint config."""
    register_all_modules()
    path = '2s-agcn/2s-agcn_8xb16-joint-u100-80e_ntu60-xsub-keypoint-2d.py'
    cfg = get_skeletongcn_cfg(path)
    # single-sample layout: N M T V C
    train_test_step(cfg, input_shape=(1, 2, 30, 17, 3))


def test_stgcn_plusplus():
    """Run the shared train/test smoke check on the ST-GCN++ joint config."""
    register_all_modules()
    path = 'stgcnpp/stgcnpp_8xb16-joint-u100-80e_ntu60-xsub-keypoint-2d.py'
    cfg = get_skeletongcn_cfg(path)
    # single-sample layout: N M T V C
    train_test_step(cfg, input_shape=(1, 2, 30, 17, 3))