|
|
|
import os
|
|
import tempfile
|
|
|
|
import torch
|
|
from mmengine.runner import load_checkpoint, save_checkpoint
|
|
from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
|
|
|
|
from mmaction.models.backbones.mobileone_tsm import MobileOneTSM
|
|
from mmaction.testing import generate_backbone_demo_inputs
|
|
|
|
|
|
def test_mobileone_tsm_backbone():
|
|
"""Test MobileOne TSM backbone."""
|
|
|
|
from mmpretrain.models.backbones.mobileone import MobileOneBlock
|
|
|
|
from mmaction.models.backbones.resnet_tsm import TemporalShift
|
|
|
|
model = MobileOneTSM('s0', pretrained2d=False)
|
|
model.init_weights()
|
|
for cur_module in model.modules():
|
|
if isinstance(cur_module, TemporalShift):
|
|
|
|
assert isinstance(cur_module.net, MobileOneBlock)
|
|
assert cur_module.num_segments == model.num_segments
|
|
assert cur_module.shift_div == model.shift_div
|
|
|
|
inputs = generate_backbone_demo_inputs((8, 3, 64, 64))
|
|
|
|
feat = model(inputs)
|
|
assert feat.shape == torch.Size([8, 1024, 2, 2])
|
|
|
|
model = MobileOneTSM('s1', pretrained2d=False)
|
|
feat = model(inputs)
|
|
assert feat.shape == torch.Size([8, 1280, 2, 2])
|
|
|
|
model = MobileOneTSM('s2', pretrained2d=False)
|
|
feat = model(inputs)
|
|
assert feat.shape == torch.Size([8, 2048, 2, 2])
|
|
|
|
model = MobileOneTSM('s3', pretrained2d=False)
|
|
feat = model(inputs)
|
|
assert feat.shape == torch.Size([8, 2048, 2, 2])
|
|
|
|
model = MobileOneTSM('s4', pretrained2d=False)
|
|
feat = model(inputs)
|
|
assert feat.shape == torch.Size([8, 2048, 2, 2])
|
|
|
|
|
|
def test_mobileone_init_weight():
|
|
checkpoint = ('https://download.openmmlab.com/mmclassification/v0'
|
|
'/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth')
|
|
|
|
model = MobileOneTSM(
|
|
arch='s0',
|
|
init_cfg=dict(
|
|
type='Pretrained', checkpoint=checkpoint, prefix='backbone'))
|
|
model.init_weights()
|
|
ori_ckpt = _load_checkpoint_with_prefix(
|
|
'backbone', model.init_cfg['checkpoint'], map_location='cpu')
|
|
for name, param in model.named_parameters():
|
|
ori_name = name.replace('.net', '')
|
|
assert torch.allclose(param, ori_ckpt[ori_name]), \
|
|
f'layer {name} fail to load from pretrained checkpoint'
|
|
|
|
|
|
def test_load_deploy_mobileone():
|
|
|
|
model = MobileOneTSM('s0', pretrained2d=False)
|
|
inputs = generate_backbone_demo_inputs((8, 3, 64, 64))
|
|
tmpdir = tempfile.gettempdir()
|
|
ckpt_path = os.path.join(tmpdir, 'ckpt.pth')
|
|
model.switch_to_deploy()
|
|
model.eval()
|
|
outputs = model(inputs)
|
|
|
|
model_deploy = MobileOneTSM('s0', pretrained2d=False, deploy=True)
|
|
save_checkpoint(model.state_dict(), ckpt_path)
|
|
load_checkpoint(model_deploy, ckpt_path)
|
|
|
|
outputs_load = model_deploy(inputs)
|
|
for feat, feat_load in zip(outputs, outputs_load):
|
|
assert torch.allclose(feat, feat_load)
|
|
os.remove(ckpt_path)
|
|
|