mmaction2 / tests /models /backbones /test_mobileone_tsm.py
niobures's picture
mmaction2
d3dbf03 verified
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
import torch
from mmengine.runner import load_checkpoint, save_checkpoint
from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
from mmaction.models.backbones.mobileone_tsm import MobileOneTSM
from mmaction.testing import generate_backbone_demo_inputs
def test_mobileone_tsm_backbone():
"""Test MobileOne TSM backbone."""
from mmpretrain.models.backbones.mobileone import MobileOneBlock
from mmaction.models.backbones.resnet_tsm import TemporalShift
model = MobileOneTSM('s0', pretrained2d=False)
model.init_weights()
for cur_module in model.modules():
if isinstance(cur_module, TemporalShift):
# TemporalShift is a wrapper of MobileOneBlock
assert isinstance(cur_module.net, MobileOneBlock)
assert cur_module.num_segments == model.num_segments
assert cur_module.shift_div == model.shift_div
inputs = generate_backbone_demo_inputs((8, 3, 64, 64))
feat = model(inputs)
assert feat.shape == torch.Size([8, 1024, 2, 2])
model = MobileOneTSM('s1', pretrained2d=False)
feat = model(inputs)
assert feat.shape == torch.Size([8, 1280, 2, 2])
model = MobileOneTSM('s2', pretrained2d=False)
feat = model(inputs)
assert feat.shape == torch.Size([8, 2048, 2, 2])
model = MobileOneTSM('s3', pretrained2d=False)
feat = model(inputs)
assert feat.shape == torch.Size([8, 2048, 2, 2])
model = MobileOneTSM('s4', pretrained2d=False)
feat = model(inputs)
assert feat.shape == torch.Size([8, 2048, 2, 2])
def test_mobileone_init_weight():
checkpoint = ('https://download.openmmlab.com/mmclassification/v0'
'/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth')
# ckpt = torch.load(checkpoint)['state_dict']
model = MobileOneTSM(
arch='s0',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint, prefix='backbone'))
model.init_weights()
ori_ckpt = _load_checkpoint_with_prefix(
'backbone', model.init_cfg['checkpoint'], map_location='cpu')
for name, param in model.named_parameters():
ori_name = name.replace('.net', '')
assert torch.allclose(param, ori_ckpt[ori_name]), \
f'layer {name} fail to load from pretrained checkpoint'
def test_load_deploy_mobileone():
# Test output before and load from deploy checkpoint
model = MobileOneTSM('s0', pretrained2d=False)
inputs = generate_backbone_demo_inputs((8, 3, 64, 64))
tmpdir = tempfile.gettempdir()
ckpt_path = os.path.join(tmpdir, 'ckpt.pth')
model.switch_to_deploy()
model.eval()
outputs = model(inputs)
model_deploy = MobileOneTSM('s0', pretrained2d=False, deploy=True)
save_checkpoint(model.state_dict(), ckpt_path)
load_checkpoint(model_deploy, ckpt_path)
outputs_load = model_deploy(inputs)
for feat, feat_load in zip(outputs, outputs_load):
assert torch.allclose(feat, feat_load)
os.remove(ckpt_path)