from typing import Any, Dict, List, Optional, Tuple, Union

import clip
import mmengine
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.dist import all_gather, get_rank
from mmengine.model import BaseModel
from mmengine.structures import LabelData

from mmaction.registry import MODELS

from .adapter import TransformerAdapter


class GatherLayer(torch.autograd.Function):
    """Gather tensors from all ranks while supporting backward propagation."""

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        ctx.save_for_backward(input)
        output = all_gather(input)
        return tuple(output)
    @staticmethod
    def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
        input, = ctx.saved_tensors
        # Pass back only the gradient slice that belongs to this rank's input.
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[get_rank()]
        return grad_out
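

# Why GatherLayer instead of a plain gather: `torch.distributed.all_gather`
# (which `mmengine.dist.all_gather` wraps) does not propagate gradients, so a
# contrastive loss over gathered features would never reach the encoder.
# A minimal sketch, assuming a hypothetical `encoder` and `batch`:
#
#     local_feats = encoder(batch)  # (B, C), requires grad
#     global_feats = torch.cat(GatherLayer.apply(local_feats), dim=0)
#     # global_feats is (world_size * B, C) and backpropagates into this
#     # rank's local_feats only.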


def text_prompt(labels_or_label_file, templates_or_template_file=None):
    """Tokenize prompts built by formatting every template with every label.

    Returns the tokenized prompts with shape
    ``(num_templates * num_labels, context_length)`` together with the
    number of templates.
    """
if isinstance(labels_or_label_file, str):
labels = mmengine.list_from_file(labels_or_label_file)
elif isinstance(labels_or_label_file, list):
labels = labels_or_label_file
else:
raise ValueError(f'`labels_or_label_file` must be `list` or `str`, '
f'but got {type(labels_or_label_file)}')
if templates_or_template_file is None:
templates = [
'a photo of action {}', 'a picture of action {}',
'Human action of {}', '{}, an action', '{} this is an action',
'{}, a video of action', 'Playing action of {}', '{}',
'Playing a kind of action, {}', 'Doing a kind of action, {}',
'Look, the human is {}', 'Can you recognize the action of {}?',
'Video classification of {}', 'A video of {}', 'The man is {}',
'The woman is {}'
]
elif isinstance(templates_or_template_file, str):
templates = mmengine.list_from_file(templates_or_template_file)
    elif not mmengine.is_seq_of(templates_or_template_file, str):
        raise ValueError(f'`templates_or_template_file` must be a list of '
                         f'`str`, `str` or `None`, but got '
                         f'{type(templates_or_template_file)}')
    else:
        templates = templates_or_template_file
    num_prompt = len(templates)
    # Templates vary in the outer loop, so the result later reshapes cleanly
    # to (num_prompt, num_classes, context_length).
    prompt = torch.cat(
        [clip.tokenize(t.format(c)) for t in templates for c in labels])
    return prompt, num_prompt
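

# A quick shape check, assuming two made-up labels: each of the 16 default
# templates is formatted with every label, and `clip.tokenize` pads prompts
# to CLIP's 77-token context window.
#
#     prompt, num_prompt = text_prompt(['archery', 'bowling'])
#     assert num_prompt == 16
#     assert tuple(prompt.shape) == (2 * 16, 77)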


@MODELS.register_module()
class ActionClip(BaseModel):
    """ActionCLIP: CLIP with a temporal transformer adapter for action
    recognition."""

def __init__(self,
clip_arch: str,
num_adapter_segs: int,
num_adapter_layers: int = 6,
to_float32: bool = False,
labels_or_label_file: Optional[Union[List[str], str]] = None,
templates_or_template_file: Optional[Union[List[str],
str]] = None,
data_preprocessor: Optional[Dict] = None,
loss: Dict = dict(type='CrossEntropyLoss', loss_weight=0.5)):
        super().__init__(data_preprocessor=data_preprocessor)
        # `clip.load` returns (model, preprocess); only the model is needed.
        self.clip = clip.load(clip_arch, device='cpu')[0]
        if to_float32:
            self.clip.float()
        self.adapter = TransformerAdapter(self.clip, num_adapter_segs,
                                          num_adapter_layers)
        self.loss = MODELS.build(loss)
if labels_or_label_file is not None:
self.prompt, self.num_prompt = text_prompt(
labels_or_label_file, templates_or_template_file)

    def encode_video(self, video):
        # Flatten (B, N, C, H, W) clips into frames, encode each frame with
        # CLIP's image encoder, then fuse frames per clip with the adapter.
        b, n, c, h, w = video.shape
        video = video.view(-1, c, h, w)
        frames_features = self.encode_image(video)
        frames_features = frames_features.view(b, n, -1)
        video_features = self.adapter(frames_features)
        return video_features

    def encode_image(self, image):
        return self.clip.encode_image(image)

    def encode_text(self, text):
        return self.clip.encode_text(text)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List] = None,
                mode: str = 'tensor'):
        """Run in ``tensor`` (video features), ``predict`` (classification
        scores) or ``loss`` (contrastive training) mode."""
if mode == 'tensor':
return self.encode_video(inputs)
        elif mode == 'predict':
            assert hasattr(self, 'prompt'), \
                '`labels_or_label_file` is required to perform prediction.'
video_features = self.encode_video(inputs)
video_features = video_features / video_features.norm(
dim=-1, keepdim=True)
bsz = len(data_samples)
num_views = video_features.shape[0] // bsz
text_features = self.encode_text(self.prompt.to(inputs.device))
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)
            # (bsz * num_views, num_prompt * num_classes) ->
            # (bsz, num_views * num_prompt, num_classes)
            similarity = (100.0 * video_features @ text_features.T). \
                view(bsz, num_views * self.num_prompt, -1)
            # Softmax over classes, then average over every view/prompt pair.
            cls_scores = F.softmax(similarity, dim=2).mean(dim=1)
for data_sample, score in zip(data_samples, cls_scores):
data_sample.pred_scores = LabelData(item=score)
return data_samples
elif mode == 'loss':
video_features = self.encode_video(inputs)
video_features = video_features / video_features.norm(
dim=-1, keepdim=True)
            # Randomly pick one template per sample, then select the prompt
            # that matches its ground-truth label.
            text_id = np.random.randint(
                self.num_prompt, size=len(data_samples))
            real_labels = [x.gt_labels.item.item() for x in data_samples]
            selected_prompt = self.prompt.view(
                self.num_prompt, -1,
                self.prompt.shape[-1])[text_id, real_labels].to(inputs.device)
text_features = self.encode_text(selected_prompt)
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)
            # Gather features from every rank so the symmetric contrastive
            # loss sees the global batch; GatherLayer keeps gradients alive.
            video_features = torch.cat(
                GatherLayer.apply(video_features), dim=0)
            text_features = torch.cat(GatherLayer.apply(text_features), dim=0)
            logit_scale = self.clip.logit_scale.exp()
            logits_per_video = logit_scale * video_features @ text_features.t()
            logits_per_text = logits_per_video.t()
            # Matched video/text pairs lie on the diagonal.
            labels = torch.arange(logits_per_video.shape[0]).to(
                logit_scale.device)
sim_loss_v2t = self.loss(logits_per_video, labels)
sim_loss_t2v = self.loss(logits_per_text, labels)
losses = dict()
losses['sim_loss_v2t'] = sim_loss_v2t
losses['sim_loss_t2v'] = sim_loss_t2v
return losses
        else:
            raise RuntimeError(
                f'Invalid mode "{mode}". '
                'Only `tensor`, `predict` and `loss` modes are supported.')
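

# A minimal end-to-end sketch, assuming hypothetical config values
# (`clip_arch='ViT-B/32'`, `num_adapter_segs=8`, and the label file path are
# illustrative, not taken from this module):
#
#     cfg = dict(
#         type='ActionClip',
#         clip_arch='ViT-B/32',
#         num_adapter_segs=8,
#         labels_or_label_file='kinetics400_labels.txt')
#     model = MODELS.build(cfg)
#     videos = torch.rand(2, 8, 3, 224, 224)  # (B, N, C, H, W), N = segs
#     feats = model(videos, mode='tensor')    # adapter-fused clip features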