import pytest
import torch
from mmengine.runner.checkpoint import _load_checkpoint

from mmaction.models import SwinTransformer3D
from mmaction.testing import generate_backbone_demo_inputs


def test_swin_backbone():
    """Test swin backbone."""
    # An invalid arch string should raise an AssertionError.
    with pytest.raises(AssertionError):
        SwinTransformer3D(arch='-t')

    # An arch dict missing required keys should raise an AssertionError.
    with pytest.raises(AssertionError):
        SwinTransformer3D(arch={'embed_dims': 96})

    # 'depths' and 'num_heads' must have matching lengths.
    with pytest.raises(AssertionError):
        SwinTransformer3D(arch={
            'embed_dims': 96,
            'depths': [2, 2, 6],
            'num_heads': [3, 6, 12, 24]
        })

    # An arch dict describing five stages is also rejected.
    with pytest.raises(AssertionError):
        SwinTransformer3D(
            arch={
                'embed_dims': 96,
                'depths': [2, 2, 6, 2, 2],
                'num_heads': [3, 6, 12, 24, 48]
            })

    # out_indices must stay within the number of stages.
    with pytest.raises(AssertionError):
        SwinTransformer3D(arch='t', out_indices=(4, ))

    # `pretrained` must be a string or None; anything else raises a
    # TypeError at weight initialization.
    with pytest.raises(TypeError):
        swin_t = SwinTransformer3D(arch='t', pretrained=[0, 1, 1])
        swin_t.init_weights()

    # The same applies when `pretrained` is passed to init_weights().
    with pytest.raises(TypeError):
        swin_t = SwinTransformer3D(arch='t')
        swin_t.init_weights(pretrained=[0, 1, 1])

    # Training from scratch, without any pretrained weights.
    swin_b = SwinTransformer3D(arch='b', pretrained=None, pretrained2d=False)
    swin_b.init_weights()
    swin_b.train()

    pretrained_url = 'https://download.openmmlab.com/mmaction/v1.0/' \
                     'recognition/swin/swin_tiny_patch4_window7_224.pth'

    # Inflate 2D ImageNet-pretrained weights into the 3D model.
    swin_t_pre = SwinTransformer3D(
        arch='t', pretrained=pretrained_url, pretrained2d=True)
    swin_t_pre.init_weights()
    swin_t_pre.train()

    # Load the raw 2D checkpoint to compare against the inflated weights.
    ckpt_2d = _load_checkpoint(pretrained_url, map_location='cpu')
    state_dict = ckpt_2d['model']

    # The patch-embed kernel is replicated along the temporal axis and
    # rescaled by the temporal kernel size.
    patch_embed_weight2d = state_dict['patch_embed.proj.weight'].data
    patch_embed_weight3d = swin_t_pre.patch_embed.proj.weight.data
    assert torch.equal(
        patch_embed_weight3d,
        patch_embed_weight2d.unsqueeze(2).expand_as(patch_embed_weight3d) /
        patch_embed_weight3d.shape[2])
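    # Why the division by the temporal kernel size t: replicating a 2D
    # kernel w as w / t along time means a temporally constant input x
    # yields sum over t of (w / t) * x = w * x, so the inflated 3D layer
    # reproduces the 2D layer's activations. (Explanatory note; the check
    # above is the behavior under test.)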

    # The final norm layer is copied over unchanged.
    norm = swin_t_pre.norm3
    assert torch.equal(norm.weight.data, state_dict['norm.weight'])
    assert torch.equal(norm.bias.data, state_dict['norm.bias'])

    # Relative position bias tables are tiled along the temporal axis.
    for name, param in swin_t_pre.named_parameters():
        if 'relative_position_bias_table' in name:
            bias2d = state_dict[name]
            assert torch.equal(
                param.data,
                bias2d.repeat(2 * swin_t_pre.window_size[0] - 1, 1))
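    # Shape note: a 2D table has (2 * Wh - 1) * (2 * Ww - 1) rows, one per
    # relative spatial offset within a window. A 3D window adds
    # (2 * Wt - 1) relative temporal offsets, so the 2D rows are repeated
    # that many times, where Wt = window_size[0] is the temporal window.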

    # Parameters of the patch embedding and of every frozen stage must not
    # receive gradients.
    frozen_stages = 1
    swin_t_frozen = SwinTransformer3D(
        arch='t',
        pretrained=None,
        pretrained2d=False,
        frozen_stages=frozen_stages)
    swin_t_frozen.init_weights()
    swin_t_frozen.train()
    for param in swin_t_frozen.patch_embed.parameters():
        assert param.requires_grad is False
    for i in range(frozen_stages):
        layer = swin_t_frozen.layers[i]
        for param in layer.parameters():
            assert param.requires_grad is False
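    # Expected output shape for swin-t: the (2, 4, 4) patch embedding maps
    # (6, 64, 64) inputs to a (3, 16, 16) token grid with 96 channels; the
    # three downsampling stages halve the spatial size and double the
    # channels, giving 96 * 8 = 768 channels on a (3, 2, 2) grid.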

    input_shape = (1, 3, 6, 64, 64)
    imgs = generate_backbone_demo_inputs(input_shape)

    feat = swin_t_frozen(imgs)
    assert feat.shape == torch.Size([1, 768, 3, 2, 2])

    # Inputs whose size is not divisible by the patch or window size are
    # padded, so the output shape is unchanged.
    input_shape = (1, 3, 5, 63, 63)
    imgs = generate_backbone_demo_inputs(input_shape)
    feat = swin_t_frozen(imgs)
    assert feat.shape == torch.Size([1, 768, 3, 2, 2])

    # With multiple out_indices, a feature is returned for every requested
    # stage (before downsampling by default).
    swin_t_all_stages = SwinTransformer3D(arch='t', out_indices=(0, 1, 2, 3))
    feats = swin_t_all_stages(imgs)
    assert feats[0].shape == torch.Size([1, 96, 3, 16, 16])
    assert feats[1].shape == torch.Size([1, 192, 3, 8, 8])
    assert feats[2].shape == torch.Size([1, 384, 3, 4, 4])
    assert feats[3].shape == torch.Size([1, 768, 3, 2, 2])

    # With out_after_downsample=True, each feature is taken after the
    # stage's downsampling layer; the last stage has none, so its output
    # is unchanged.
    swin_t_all_stages_after_ds = SwinTransformer3D(
        arch='t', out_indices=(0, 1, 2, 3), out_after_downsample=True)
    feats = swin_t_all_stages_after_ds(imgs)
    assert feats[0].shape == torch.Size([1, 192, 3, 8, 8])
    assert feats[1].shape == torch.Size([1, 384, 3, 4, 4])
    assert feats[2].shape == torch.Size([1, 768, 3, 2, 2])
    assert feats[3].shape == torch.Size([1, 768, 3, 2, 2])
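

# Minimal sketch of the 2D->3D kernel inflation verified above. This is a
# hypothetical helper for illustration only, not part of the mmaction API.
def _inflate_conv2d_weight(weight2d, t):
    """Replicate a 2D conv kernel ``t`` times along a new temporal axis.

    Rescaling by ``1 / t`` keeps the response to a temporally constant
    input identical to the original 2D layer's response.
    """
    # (C_out, C_in, kH, kW) -> (C_out, C_in, t, kH, kW), scaled by 1 / t.
    return weight2d.unsqueeze(2).expand(-1, -1, t, -1, -1) / t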