# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmaction.models.common import Conv2plus1d


def test_conv2plus1d():
    """Check Conv2plus1d construction, weight init and forward output shape."""
    # Length of kernel size, stride and padding must be the same, so a
    # 2-element kernel size is expected to trip the constructor's assertion.
    with pytest.raises(AssertionError):
        Conv2plus1d(3, 8, (2, 2))

    module = Conv2plus1d(3, 8, 2)
    module.init_weights()

    # After init_weights() the ``bn_s`` layer must have unit weight and
    # zero bias.
    bn = module.bn_s
    expected_weight = torch.ones_like(bn.weight)
    expected_bias = torch.zeros_like(bn.bias)
    assert torch.equal(bn.weight, expected_weight)
    assert torch.equal(bn.bias, expected_bias)

    # Forward a random 5-D tensor and verify the resulting shape.
    inputs = torch.rand(1, 3, 8, 256, 256)
    result = module(inputs)
    expected_shape = torch.Size([1, 8, 7, 255, 255])
    assert result.shape == expected_shape