# niobures's picture
# mmaction2
# d3dbf03 verified
import torch
import torch.nn as nn
from mmengine.model import BaseModule, ModuleList
from mmaction.models.utils import Graph, unit_tcn
from mmaction.registry import MODELS
from .ctrgcn_utils import MSTCN, unit_ctrgcn
class _ZeroResidual(nn.Module):
    """Residual branch that contributes nothing to the block output.

    Implemented as a module rather than a ``lambda`` so that a block holding
    it can still be pickled.
    """

    def forward(self, x):
        # Return the scalar 0 (not a zero tensor): adding it to the main
        # branch leaves that tensor's values unchanged.
        return 0


class CTRGCNBlock(BaseModule):
    """A ``unit_ctrgcn`` layer followed by an ``MSTCN`` layer, plus a residual.

    The block output is ``relu(tcn1(gcn1(x)) + residual(x))``.

    Args:
        in_channels (int): Input channel count, passed to ``unit_ctrgcn``.
        out_channels (int): Output channel count of ``unit_ctrgcn``; also
            used as both the input and output channel count of ``MSTCN``.
        A: Forwarded unchanged to ``unit_ctrgcn``.
            NOTE(review): presumably the graph adjacency tensor - confirm
            against ``unit_ctrgcn``.
        stride (int): Stride given to ``MSTCN`` and, when a projection is
            needed, to the residual ``unit_tcn``. Defaults to 1.
        residual (bool): If False, the residual branch contributes 0.
            Defaults to True.
        kernel_size (int): Forwarded to ``MSTCN``. Defaults to 5.
        dilations (list[int] | None): Forwarded to ``MSTCN``. ``None``
            (the default) means ``[1, 2]``.
        tcn_dropout (float): Forwarded to ``MSTCN``. Defaults to 0.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 A,
                 stride=1,
                 residual=True,
                 kernel_size=5,
                 dilations=None,
                 tcn_dropout=0):
        super().__init__()

        # Build a fresh list per instance instead of sharing one mutable
        # default object across every call.
        if dilations is None:
            dilations = [1, 2]

        self.gcn1 = unit_ctrgcn(in_channels, out_channels, A)
        # The MSTCN's own residual is disabled; this block adds its own.
        self.tcn1 = MSTCN(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilations=dilations,
            residual=False,
            tcn_dropout=tcn_dropout)
        self.relu = nn.ReLU(inplace=True)

        # The residual branch is always an nn.Module (never a lambda) so the
        # block stays picklable. None of the parameter-free variants add
        # entries to the state dict.
        if not residual:
            self.residual = _ZeroResidual()
        elif in_channels == out_channels and stride == 1:
            # Input already has the output's channel count and no striding
            # is applied, so it can be added back unchanged.
            self.residual = nn.Identity()
        else:
            # Channel count or stride differs: project the input with a
            # kernel-size-1 unit_tcn using the same stride.
            self.residual = unit_tcn(
                in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        """Apply gcn1 -> tcn1, add the residual branch, then ReLU."""
        main = self.tcn1(self.gcn1(x))
        return self.relu(main + self.residual(x))
@MODELS.register_module()
class CTRGCN(BaseModule):
    """Backbone made of a stack of ``CTRGCNBlock`` stages.

    Args:
        graph_cfg (dict): Keyword arguments used to construct ``Graph``. The
            resulting ``graph.A`` is stored as a float32 buffer named ``A``
            and a clone of it is handed to every block.
        in_channels (int): Size of the last input dimension (``C`` in
            :meth:`forward`). Defaults to 3.
        base_channels (int): Output channels of the first stage.
            Defaults to 64.
        num_stages (int): Total number of ``CTRGCNBlock`` stages.
            Defaults to 10.
        inflate_stages (Sequence[int]): 1-based stage indices whose output
            channel count is twice their input. Defaults to ``(5, 8)``.
        down_stages (Sequence[int]): 1-based stage indices built with
            ``stride=2`` instead of 1. Defaults to ``(5, 8)``.
        pretrained: Accepted but not used by this class.
        num_person (int): Expected size of input dimension 1 (``M`` in
            :meth:`forward`); it determines the ``data_bn`` feature count.
            Defaults to 2.
        **kwargs: Extra keyword arguments forwarded to every
            ``CTRGCNBlock``. ``tcn_dropout`` is withheld from the first
            stage only.
    """

    def __init__(self,
                 graph_cfg,
                 in_channels=3,
                 base_channels=64,
                 num_stages=10,
                 inflate_stages=(5, 8),
                 down_stages=(5, 8),
                 pretrained=None,
                 num_person=2,
                 **kwargs):
        super().__init__()

        self.graph = Graph(**graph_cfg)
        A = torch.tensor(
            self.graph.A, dtype=torch.float32, requires_grad=False)
        # Registered as a buffer so it follows the module across devices and
        # is saved with it, without being a trainable parameter.
        self.register_buffer('A', A)

        self.num_person = num_person
        # Records the width of the first stage (the value before inflation).
        self.base_channels = base_channels

        # Feature count matches the (M * V * C) flattening done in forward().
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * A.size(1))

        # The first stage has no residual branch and never receives
        # ``tcn_dropout``.
        first_stage_kwargs = {
            key: value
            for key, value in kwargs.items() if key != 'tcn_dropout'
        }
        modules = [
            CTRGCNBlock(
                in_channels,
                base_channels,
                A.clone(),
                residual=False,
                **first_stage_kwargs)
        ]

        # Stages are numbered from 1; stage 1 was built above.
        channels = base_channels
        for stage in range(2, num_stages + 1):
            out_channels = channels * (2 if stage in inflate_stages else 1)
            stride = 2 if stage in down_stages else 1
            modules.append(
                CTRGCNBlock(
                    channels,
                    out_channels,
                    A.clone(),
                    stride=stride,
                    **kwargs))
            channels = out_channels

        self.net = ModuleList(modules)

    def forward(self, x):
        """Run the input through ``data_bn`` and every stage.

        Args:
            x (torch.Tensor): 5-D tensor unpacked as ``(N, M, T, V, C)``.
                NOTE(review): presumably batch, persons, frames, joints,
                coordinate channels - confirm against the data pipeline.

        Returns:
            torch.Tensor: Output of the last stage with its leading
            ``N * M`` dimension split back into ``(N, M)``.
        """
        N, M, T, V, C = x.size()

        # (N, M, T, V, C) -> (N, M, V, C, T), then flatten M, V, C into the
        # BatchNorm1d feature axis.
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = self.data_bn(x.view(N, M * V * C, T))

        # (N, M, V, C, T) -> (N, M, C, T, V), then fold M into the batch axis
        # so each of the M slices passes through the stages independently.
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)

        for block in self.net:
            x = block(x)

        # Restore the separate (N, M) leading dimensions.
        return x.reshape((N, M) + x.shape[1:])