import torch
import torch.nn as nn
from mmengine.model import BaseModule, ModuleList

from mmaction.models.utils import Graph, unit_tcn
from mmaction.registry import MODELS
from .ctrgcn_utils import MSTCN, unit_ctrgcn


def _zero_residual(x):
    """Residual branch that contributes nothing (used when ``residual=False``)."""
    return 0


def _identity_residual(x):
    """Residual branch that returns its input unchanged."""
    return x


class CTRGCNBlock(BaseModule):
    """One CTR-GCN block: ``relu(tcn1(gcn1(x)) + residual(x))``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        A (torch.Tensor): Adjacency tensor handed to ``unit_ctrgcn``.
        stride (int): Stride passed to ``MSTCN`` and, when a projection is
            needed, to the residual ``unit_tcn``. Defaults to 1.
        residual (bool): Whether to add a residual connection.
            Defaults to True.
        kernel_size (int): Kernel size passed to ``MSTCN``. Defaults to 5.
        dilations (Sequence[int]): Dilations passed to ``MSTCN`` (converted
            to a list). Defaults to ``(1, 2)``.
        tcn_dropout (float): Dropout passed to ``MSTCN``. Defaults to 0.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 A,
                 stride=1,
                 residual=True,
                 kernel_size=5,
                 dilations=(1, 2),
                 tcn_dropout=0):
        super().__init__()
        self.gcn1 = unit_ctrgcn(in_channels, out_channels, A)
        # A fresh list per instance: no mutable default is ever shared.
        self.tcn1 = MSTCN(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilations=list(dilations),
            residual=False,
            tcn_dropout=tcn_dropout)
        self.relu = nn.ReLU(inplace=True)

        # Module-level functions instead of lambdas keep the module picklable
        # (e.g. for torch.save(model) or multiprocessing).
        if not residual:
            self.residual = _zero_residual
        elif in_channels == out_channels and stride == 1:
            self.residual = _identity_residual
        else:
            # Channel count or stride differs: project with a kernel-size-1
            # unit_tcn so the residual matches the main branch.
            self.residual = unit_tcn(
                in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        """Return ``relu(tcn1(gcn1(x)) + residual(x))``."""
        return self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))


@MODELS.register_module()
class CTRGCN(BaseModule):
    """CTR-GCN backbone: input BatchNorm followed by a stack of CTRGCNBlocks.

    Args:
        graph_cfg (dict): Keyword arguments used to build ``Graph``. Its
            ``A`` attribute is registered as the float32 buffer ``self.A``.
        in_channels (int): Channels per joint in the input. Defaults to 3.
        base_channels (int): Output width of the first block. Defaults to 64.
        num_stages (int): Total number of blocks. Defaults to 10.
        inflate_stages (Sequence[int]): 1-based stage indices (>= 2) whose
            output width is double their input width. Defaults to ``(5, 8)``.
        down_stages (Sequence[int]): 1-based stage indices (>= 2) built with
            ``stride=2``. Defaults to ``(5, 8)``.
        pretrained: Accepted for config compatibility; not used by this
            class.
        num_person (int): Number of persons ``M`` per sample; sizes
            ``data_bn``. Defaults to 2.
        **kwargs: Forwarded to every ``CTRGCNBlock``. ``tcn_dropout`` is
            dropped for the first block only.
    """

    def __init__(self,
                 graph_cfg,
                 in_channels=3,
                 base_channels=64,
                 num_stages=10,
                 inflate_stages=(5, 8),
                 down_stages=(5, 8),
                 pretrained=None,
                 num_person=2,
                 **kwargs):
        super().__init__()

        self.graph = Graph(**graph_cfg)
        A = torch.tensor(
            self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer('A', A)

        self.num_person = num_person
        self.base_channels = base_channels

        # forward() feeds M * V * C features per frame into this layer, so
        # A.size(1) must equal the joint axis V of the input.
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * A.size(1))

        # The first block has no residual and never receives tcn_dropout.
        first_kwargs = {k: v for k, v in kwargs.items() if k != 'tcn_dropout'}
        modules = [
            CTRGCNBlock(
                in_channels,
                base_channels,
                A.clone(),
                residual=False,
                **first_kwargs)
        ]
        for i in range(2, num_stages + 1):
            # bool arithmetic: inflate stages double the width,
            # down stages use stride 2.
            out_channels = base_channels * (1 + (i in inflate_stages))
            stride = 1 + (i in down_stages)
            modules.append(
                CTRGCNBlock(
                    base_channels,
                    out_channels,
                    A.clone(),
                    stride=stride,
                    **kwargs))
            base_channels = out_channels

        self.net = ModuleList(modules)

    def forward(self, x):
        """Run the backbone.

        Args:
            x (torch.Tensor): Input of size ``(N, M, T, V, C)``. For
                ``data_bn`` to match, ``M * V * C`` must equal
                ``num_person * in_channels * A.size(1)``.

        Returns:
            torch.Tensor: The last block's output reshaped to
            ``(N, M, *per_sample_dims)``. Presumably ``(N, M, C', T', V)``;
            this depends on the blocks, so confirm.
        """
        N, M, T, V, C = x.size()
        # (N, M, T, V, C) -> (N, M, V, C, T): normalize M*V*C features per frame.
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = self.data_bn(x.view(N, M * V * C, T))
        # (N, M, V, C, T) -> (N*M, C, T, V): each person becomes its own sample.
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)

        for gcn in self.net:
            x = gcn(x)

        return x.reshape((N, M) + x.shape[1:])