import argparse
import copy as cp
import tempfile
import warnings

import cv2
import mmcv
import mmengine
import numpy as np
import torch
from mmengine import DictAction
from mmengine.structures import InstanceData

from mmaction.apis import (detection_inference, inference_recognizer,
                           inference_skeleton, init_recognizer, pose_inference)
from mmaction.registry import VISUALIZERS
from mmaction.structures import ActionDataSample
from mmaction.utils import frame_extract

try:
    from mmdet.apis import init_detector
except (ImportError, ModuleNotFoundError):
    warnings.warn('Failed to import `init_detector` from `mmdet.apis`. '
                  'These APIs are required in skeleton-based applications!')

try:
    import moviepy.editor as mpy
except ImportError:
    raise ImportError('Please install moviepy to enable output file')

FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.5
FONTCOLOR = (255, 255, 255)
MSGCOLOR = (128, 128, 128)
THICKNESS = 1
LINETYPE = 1


def hex2color(h):
    """Convert a 6-digit hex string to a tuple of 3 int values (RGB)."""
    return (int(h[:2], 16), int(h[2:4], 16), int(h[4:], 16))


PLATEBLUE = '03045e-023e8a-0077b6-0096c7-00b4d8-48cae4'
PLATEBLUE = PLATEBLUE.split('-')
PLATEBLUE = [hex2color(h) for h in PLATEBLUE]
PLATEGREEN = '004b23-006400-007200-008000-38b000-70e000'
PLATEGREEN = PLATEGREEN.split('-')
PLATEGREEN = [hex2color(h) for h in PLATEGREEN]


def visualize(args,
              frames,
              annotations,
              pose_data_samples,
              action_result,
              plate=PLATEBLUE,
              max_num=5):
    """Visualize frames with predicted annotations.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        frames (list[np.ndarray]): Frames for visualization, note that
            len(frames) % len(annotations) should be 0.
        annotations (list[list[tuple]]): The predicted spatio-temporal
            detection results.
        pose_data_samples (list[PoseDataSample]): The pose results, one
            sample per frame.
        action_result (str): The predicted action recognition result.
        plate (list[tuple]): The color plate used for visualization.
            Default: PLATEBLUE.
        max_num (int): Max number of labels to visualize for a person box.
            Default: 5.

    Returns:
        list[np.ndarray]: Visualized frames.
    """
    assert max_num + 1 <= len(plate)
    frames_ = cp.deepcopy(frames)
    frames_ = [mmcv.imconvert(f, 'bgr', 'rgb') for f in frames_]
    nf, na = len(frames), len(annotations)
    assert nf % na == 0
    nfpa = len(frames) // len(annotations)
    anno = None
    h, w, _ = frames[0].shape
    scale_ratio = np.array([w, h, w, h])

    if pose_data_samples:
        pose_config = mmengine.Config.fromfile(args.pose_config)
        visualizer = VISUALIZERS.build(pose_config.visualizer)
        visualizer.set_dataset_meta(pose_data_samples[0].dataset_meta)
        for i, (d, f) in enumerate(zip(pose_data_samples, frames_)):
            visualizer.add_datasample(
                'result',
                f,
                data_sample=d,
                draw_gt=False,
                draw_heatmap=False,
                draw_bbox=True,
                show=False,
                wait_time=0,
                out_file=None,
                kpt_thr=0.3)
            frames_[i] = visualizer.get_image()
            cv2.putText(frames_[i], action_result, (10, 30), FONTFACE,
                        FONTSCALE, FONTCOLOR, THICKNESS, LINETYPE)

    for i in range(na):
        anno = annotations[i]
        if anno is None:
            continue
        for j in range(nfpa):
            ind = i * nfpa + j
            frame = frames_[ind]

            cv2.putText(frame, action_result, (10, 30), FONTFACE, FONTSCALE,
                        FONTCOLOR, THICKNESS, LINETYPE)

            for ann in anno:
                box = ann[0]
                label = ann[1]
                if not len(label):
                    continue
                score = ann[2]
                box = (box * scale_ratio).astype(np.int64)
                st, ed = tuple(box[:2]), tuple(box[2:])
                if not pose_data_samples:
                    cv2.rectangle(frame, st, ed, plate[0], 2)

                for k, lb in enumerate(label):
                    if k >= max_num:
                        break
                    text = abbrev(lb)
                    text = ': '.join([text, f'{score[k]:.3f}'])
                    location = (0 + st[0], 18 + k * 18 + st[1])
                    textsize = cv2.getTextSize(text, FONTFACE, FONTSCALE,
                                               THICKNESS)[0]
                    textwidth = textsize[0]
                    diag0 = (location[0] + textwidth, location[1] - 14)
                    diag1 = (location[0], location[1] + 2)
                    cv2.rectangle(frame, diag0, diag1, plate[k + 1], -1)
                    cv2.putText(frame, text, location, FONTFACE, FONTSCALE,
                                FONTCOLOR, THICKNESS, LINETYPE)

    return frames_


def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    parser.add_argument(
        '--rgb-stdet-config',
        default=(
            'configs/detection/slowonly/'
            'slowonly_kinetics400-pretrained-r101_8xb16-8x8x1-20e_ava21-rgb.py'
        ),
        help='rgb-based spatio temporal detection config file path')
    parser.add_argument(
        '--rgb-stdet-checkpoint',
        default=('https://download.openmmlab.com/mmaction/detection/ava/'
                 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb/'
                 'slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb'
                 '_20201217-16378594.pth'),
        help='rgb-based spatio temporal detection checkpoint file/url')
    parser.add_argument(
        '--skeleton-stdet-checkpoint',
        default=('https://download.openmmlab.com/mmaction/skeleton/posec3d/'
                 'posec3d_ava.pth'),
        help='skeleton-based spatio temporal detection checkpoint file/url')
    parser.add_argument(
        '--det-config',
        default='demo/demo_configs/faster-rcnn_r50_fpn_2x_coco_infer.py',
        help='human detection config file path (from mmdet)')
    parser.add_argument(
        '--det-checkpoint',
        default=('http://download.openmmlab.com/mmdetection/v2.0/'
                 'faster_rcnn/faster_rcnn_r50_fpn_2x_coco/'
                 'faster_rcnn_r50_fpn_2x_coco_'
                 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
        help='human detection checkpoint file/url')
    parser.add_argument(
        '--pose-config',
        default='demo/demo_configs'
        '/td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py',
        help='human pose estimation config file path (from mmpose)')
    parser.add_argument(
        '--pose-checkpoint',
        default=('https://download.openmmlab.com/mmpose/top_down/hrnet/'
                 'hrnet_w32_coco_256x192-c78dce93_20200708.pth'),
        help='human pose estimation checkpoint file/url')
    parser.add_argument(
        '--skeleton-config',
        default='configs/skeleton/posec3d'
        '/slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py',
        help='skeleton-based action recognition config file path')
    parser.add_argument(
        '--skeleton-checkpoint',
        default='https://download.openmmlab.com/mmaction/skeleton/posec3d/'
        'posec3d_k400.pth',
        help='skeleton-based action recognition checkpoint file/url')
    parser.add_argument(
        '--rgb-config',
        default='configs/recognition/tsn/'
        'tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py',
        help='rgb-based action recognition config file path')
    parser.add_argument(
        '--rgb-checkpoint',
        default='https://download.openmmlab.com/mmaction/recognition/'
        'tsn/tsn_r50_1x1x3_100e_kinetics400_rgb/'
        'tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth',
        help='rgb-based action recognition checkpoint file/url')
    parser.add_argument(
        '--use-skeleton-stdet',
        action='store_true',
        help='use skeleton-based spatio temporal detection method')
    parser.add_argument(
        '--use-skeleton-recog',
        action='store_true',
        help='use skeleton-based action recognition method')
    parser.add_argument(
        '--det-score-thr',
        type=float,
        default=0.9,
        help='the threshold of human detection score')
    parser.add_argument(
        '--action-score-thr',
        type=float,
        default=0.4,
        help='the threshold of action prediction score')
    parser.add_argument(
        '--video',
        default='demo/test_video_structuralize.mp4',
        help='video file/url')
    parser.add_argument(
        '--label-map-stdet',
        default='tools/data/ava/label_map.txt',
        help='label map file for spatio-temporal action detection')
    parser.add_argument(
        '--label-map',
        default='tools/data/kinetics/label_map_k400.txt',
        help='label map file for action recognition')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--out-filename',
        default='demo/test_stdet_recognition_output.mp4',
        help='output filename')
    parser.add_argument(
        '--predict-stepsize',
        default=8,
        type=int,
        help='give out a spatio-temporal detection prediction per n frames')
    parser.add_argument(
        '--output-stepsize',
        default=1,
        type=int,
        help=('show one frame per n frames in the demo, we should have: '
              'predict_stepsize % output_stepsize == 0'))
    parser.add_argument(
        '--output-fps',
        default=24,
        type=int,
        help='the fps of demo video output')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    args = parser.parse_args()
    return args


def load_label_map(file_path):
    """Load Label Map.

    Args:
        file_path (str): The file path of label map.

    Returns:
        dict: The label map (int -> label name).
    """
    lines = open(file_path).readlines()
    lines = [x.strip().split(': ') for x in lines]
    return {int(x[0]): x[1] for x in lines}


def abbrev(name):
    """Get the abbreviation of label name:

    'take (an object) from (a person)' -> 'take ... from ...'
    """
    while name.find('(') != -1:
        st, ed = name.find('('), name.find(')')
        name = name[:st] + '...' + name[ed + 1:]
    return name


def pack_result(human_detection, result, img_h, img_w):
    """Pack the human proposals and their action predictions into results.

    Args:
        human_detection (torch.Tensor): Human detection result (bboxes).
        result (list): The predicted labels of each human proposal.
        img_h (int): The image height.
        img_w (int): The image width.

    Returns:
        list | None: A list of tuples, each containing the human proposal,
            the label names and the label scores, or None if ``result``
            is None.
    """
    human_detection[:, 0::2] /= img_w
    human_detection[:, 1::2] /= img_h
    results = []
    if result is None:
        return None
    for prop, res in zip(human_detection, result):
        res.sort(key=lambda x: -x[1])
        results.append((prop.data.cpu().numpy(), [x[0] for x in res],
                        [x[1] for x in res]))
    return results


def expand_bbox(bbox, h, w, ratio=1.25):
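    """Expand a person bounding box to a square scaled by ``ratio``, clipped
    to the image boundary.

    Args:
        bbox (tuple | np.ndarray): Box in (x1, y1, x2, y2) format.
        h (int): Image height.
        w (int): Image width.
        ratio (float): Expansion ratio of the square side. Default: 1.25.

    Returns:
        tuple: The expanded box (x1, y1, x2, y2).
    """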
    x1, y1, x2, y2 = bbox
    center_x = (x1 + x2) // 2
    center_y = (y1 + y2) // 2
    width = x2 - x1
    height = y2 - y1

    square_l = max(width, height)
    new_width = new_height = square_l * ratio

    new_x1 = max(0, int(center_x - new_width / 2))
    new_x2 = min(int(center_x + new_width / 2), w)
    new_y1 = max(0, int(center_y - new_height / 2))
    new_y2 = min(int(center_y + new_height / 2), h)
    return (new_x1, new_y1, new_x2, new_y2)


def cal_iou(box1, box2):
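    """Compute the IoU (intersection over union) of two boxes given in
    (x1, y1, x2, y2) format.

    Args:
        box1 (tuple | np.ndarray): The first box.
        box2 (tuple | np.ndarray): The second box.

    Returns:
        float: The IoU of the two boxes.
    """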
    xmin1, ymin1, xmax1, ymax1 = box1
    xmin2, ymin2, xmax2, ymax2 = box2

    s1 = (xmax1 - xmin1) * (ymax1 - ymin1)
    s2 = (xmax2 - xmin2) * (ymax2 - ymin2)

    xmin = max(xmin1, xmin2)
    ymin = max(ymin1, ymin2)
    xmax = min(xmax1, xmax2)
    ymax = min(ymax1, ymax2)

    w = max(0, xmax - xmin)
    h = max(0, ymax - ymin)
    intersect = w * h
    union = s1 + s2 - intersect
    iou = intersect / union

    return iou


def skeleton_based_action_recognition(args, pose_results, h, w):
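    """Perform skeleton-based action recognition on the whole video.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        pose_results (list[dict]): Per-frame pose estimation results.
        h (int): Frame height.
        w (int): Frame width.

    Returns:
        str: The name of the predicted action class.
    """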
    label_map = [x.strip() for x in open(args.label_map).readlines()]
    num_class = len(label_map)

    skeleton_config = mmengine.Config.fromfile(args.skeleton_config)
    skeleton_config.model.cls_head.num_classes = num_class

    skeleton_model = init_recognizer(
        skeleton_config, args.skeleton_checkpoint, device=args.device)
    result = inference_skeleton(skeleton_model, pose_results, (h, w))
    action_idx = result.pred_score.argmax().item()
    return label_map[action_idx]


def rgb_based_action_recognition(args):
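    """Perform RGB-based action recognition on the whole video.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.

    Returns:
        str: The name of the predicted action class.
    """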
    rgb_config = mmengine.Config.fromfile(args.rgb_config)
    rgb_config.model.backbone.pretrained = None
    rgb_model = init_recognizer(rgb_config, args.rgb_checkpoint, args.device)
    action_results = inference_recognizer(rgb_model, args.video)
    rgb_action_result = action_results.pred_score.argmax().item()
    label_map = [x.strip() for x in open(args.label_map).readlines()]
    return label_map[rgb_action_result]


def skeleton_based_stdet(args, label_map, human_detections, pose_results,
                         num_frame, clip_len, frame_interval, h, w):
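    """Perform skeleton-based spatio-temporal action detection.

    For each sampled timestamp, the pose closest (by IoU of bounding boxes)
    to each detected person is collected over a clip of ``clip_len`` frames
    and fed to the skeleton-based recognizer.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        label_map (dict): Label map (int -> label name) for detection.
        human_detections (list[np.ndarray]): Per-frame human detections.
        pose_results (list[dict]): Per-frame pose estimation results.
        num_frame (int): Total number of frames.
        clip_len (int): Number of frames in each clip.
        frame_interval (int): Sampling interval between frames in a clip.
        h (int): Frame height.
        w (int): Frame width.

    Returns:
        tuple: Sampled timestamps and the per-timestamp predictions.
    """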
    window_size = clip_len * frame_interval
    assert clip_len % 2 == 0, 'We would like to have an even clip_len'
    timestamps = np.arange(window_size // 2, num_frame + 1 - window_size // 2,
                           args.predict_stepsize)

    skeleton_config = mmengine.Config.fromfile(args.skeleton_config)
    num_class = max(label_map.keys()) + 1
    skeleton_config.model.cls_head.num_classes = num_class
    skeleton_stdet_model = init_recognizer(skeleton_config,
                                           args.skeleton_stdet_checkpoint,
                                           args.device)

    skeleton_predictions = []

    print('Performing SpatioTemporal Action Detection for each clip')
    prog_bar = mmengine.ProgressBar(len(timestamps))
    for timestamp in timestamps:
        proposal = human_detections[timestamp - 1]
        if proposal.shape[0] == 0:
            skeleton_predictions.append(None)
            continue

        start_frame = timestamp - (clip_len // 2 - 1) * frame_interval
        frame_inds = start_frame + np.arange(0, window_size, frame_interval)
        frame_inds = list(frame_inds - 1)
        num_frame = len(frame_inds)

        pose_result = [pose_results[ind] for ind in frame_inds]

        skeleton_prediction = []
        for i in range(proposal.shape[0]):
            skeleton_prediction.append([])

            fake_anno = dict(
                frame_dict='',
                label=-1,
                img_shape=(h, w),
                origin_shape=(h, w),
                start_index=0,
                modality='Pose',
                total_frames=num_frame)
            num_person = 1

            num_keypoint = 17
            keypoint = np.zeros(
                (num_person, num_frame, num_keypoint, 2))
            keypoint_score = np.zeros(
                (num_person, num_frame, num_keypoint))

            person_bbox = proposal[i][:4]
            area = expand_bbox(person_bbox, h, w)

            for j, poses in enumerate(pose_result):
                max_iou = float('-inf')
                index = -1
                if len(poses['keypoints']) == 0:
                    continue
                for k, bbox in enumerate(poses['bboxes']):
                    iou = cal_iou(bbox, area)
                    if max_iou < iou:
                        index = k
                        max_iou = iou
                keypoint[0, j] = poses['keypoints'][index]
                keypoint_score[0, j] = poses['keypoint_scores'][index]

            fake_anno['keypoint'] = keypoint
            fake_anno['keypoint_score'] = keypoint_score

            output = inference_recognizer(skeleton_stdet_model, fake_anno)

            score = output.pred_score.tolist()
            for k in range(len(score)):
                if k not in label_map:
                    continue
                if score[k] > args.action_score_thr:
                    skeleton_prediction[i].append((label_map[k], score[k]))

        skeleton_predictions.append(skeleton_prediction)
        prog_bar.update()

    return timestamps, skeleton_predictions


def rgb_based_stdet(args, frames, label_map, human_detections, w, h, new_w,
                    new_h, w_ratio, h_ratio):
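    """Perform RGB-based spatio-temporal action detection.

    For each sampled timestamp, a clip of frames is normalized, stacked and
    fed to the RGB-based detector together with the human proposals of the
    center frame.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        frames (list[np.ndarray]): Resized video frames.
        label_map (dict): Label map (int -> label name) for detection.
        human_detections (list[torch.Tensor]): Per-frame human proposals.
        w (int): Original frame width.
        h (int): Original frame height.
        new_w (int): Resized frame width.
        new_h (int): Resized frame height.
        w_ratio (float): new_w / w.
        h_ratio (float): new_h / h.

    Returns:
        tuple: Sampled timestamps and the per-timestamp predictions.
    """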
    rgb_stdet_config = mmengine.Config.fromfile(args.rgb_stdet_config)
    rgb_stdet_config.merge_from_dict(args.cfg_options)

    val_pipeline = rgb_stdet_config.val_pipeline
    sampler = [x for x in val_pipeline if x['type'] == 'SampleAVAFrames'][0]
    clip_len, frame_interval = sampler['clip_len'], sampler['frame_interval']
    assert clip_len % 2 == 0, 'We would like to have an even clip_len'

    window_size = clip_len * frame_interval
    num_frame = len(frames)

    timestamps = np.arange(window_size // 2, num_frame + 1 - window_size // 2,
                           args.predict_stepsize)

    img_norm_cfg = dict(
        mean=np.array(rgb_stdet_config.model.data_preprocessor.mean),
        std=np.array(rgb_stdet_config.model.data_preprocessor.std),
        to_rgb=False)

    try:
        rgb_stdet_config['model']['test_cfg']['rcnn'] = dict(action_thr=0)
    except KeyError:
        pass

    rgb_stdet_config.model.backbone.pretrained = None
    rgb_stdet_model = init_detector(
        rgb_stdet_config, args.rgb_stdet_checkpoint, device=args.device)

    predictions = []

    print('Performing SpatioTemporal Action Detection for each clip')
    prog_bar = mmengine.ProgressBar(len(timestamps))

    for timestamp in timestamps:
        proposal = human_detections[timestamp - 1]
        if proposal.shape[0] == 0:
            predictions.append(None)
            continue

        start_frame = timestamp - (clip_len // 2 - 1) * frame_interval
        frame_inds = start_frame + np.arange(0, window_size, frame_interval)
        frame_inds = list(frame_inds - 1)

        imgs = [frames[ind].astype(np.float32) for ind in frame_inds]
        _ = [mmcv.imnormalize_(img, **img_norm_cfg) for img in imgs]

        input_array = np.stack(imgs).transpose((3, 0, 1, 2))[np.newaxis]
        input_tensor = torch.from_numpy(input_array).to(args.device)

        datasample = ActionDataSample()
        datasample.proposals = InstanceData(bboxes=proposal)
        datasample.set_metainfo(dict(img_shape=(new_h, new_w)))
        with torch.no_grad():
            result = rgb_stdet_model(
                input_tensor, [datasample], mode='predict')
            scores = result[0].pred_instances.scores
            prediction = []

            for i in range(proposal.shape[0]):
                prediction.append([])

            for i in range(scores.shape[1]):
                if i not in label_map:
                    continue
                for j in range(proposal.shape[0]):
                    if scores[j, i] > args.action_score_thr:
                        prediction[j].append(
                            (label_map[i], scores[j, i].item()))
            predictions.append(prediction)
        prog_bar.update()

    return timestamps, predictions


def main():
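    """Run the whole demo: extract frames, detect humans, optionally estimate
    poses, recognize the video-level action, perform spatio-temporal action
    detection and render the results into an output video."""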
    args = parse_args()
    tmp_dir = tempfile.TemporaryDirectory()
    frame_paths, original_frames = frame_extract(
        args.video, out_dir=tmp_dir.name)
    num_frame = len(frame_paths)
    h, w, _ = original_frames[0].shape

    human_detections, _ = detection_inference(
        args.det_config,
        args.det_checkpoint,
        frame_paths,
        args.det_score_thr,
        device=args.device)
    pose_datasample = None
    if args.use_skeleton_recog or args.use_skeleton_stdet:
        pose_results, pose_datasample = pose_inference(
            args.pose_config,
            args.pose_checkpoint,
            frame_paths,
            human_detections,
            device=args.device)

    new_w, new_h = mmcv.rescale_size((w, h), (256, np.Inf))
    frames = [mmcv.imresize(img, (new_w, new_h)) for img in original_frames]
    w_ratio, h_ratio = new_w / w, new_h / h

    stdet_label_map = load_label_map(args.label_map_stdet)
    rgb_stdet_config = mmengine.Config.fromfile(args.rgb_stdet_config)
    rgb_stdet_config.merge_from_dict(args.cfg_options)
    try:
        if rgb_stdet_config['data']['train']['custom_classes'] is not None:
            stdet_label_map = {
                id + 1: stdet_label_map[cls]
                for id, cls in enumerate(rgb_stdet_config['data']['train']
                                         ['custom_classes'])
            }
    except KeyError:
        pass

    action_result = None
    if args.use_skeleton_recog:
        print('Use skeleton-based recognition')
        action_result = skeleton_based_action_recognition(
            args, pose_results, h, w)
    else:
        print('Use rgb-based recognition')
        action_result = rgb_based_action_recognition(args)

    stdet_preds = None
    if args.use_skeleton_stdet:
        print('Use skeleton-based SpatioTemporal Action Detection')
        clip_len, frame_interval = 30, 1
        timestamps, stdet_preds = skeleton_based_stdet(args, stdet_label_map,
                                                       human_detections,
                                                       pose_results, num_frame,
                                                       clip_len,
                                                       frame_interval, h, w)
        for i in range(len(human_detections)):
            det = human_detections[i]
            det[:, 0:4:2] *= w_ratio
            det[:, 1:4:2] *= h_ratio
            human_detections[i] = torch.from_numpy(det[:, :4]).to(args.device)
    else:
        print('Use rgb-based SpatioTemporal Action Detection')
        for i in range(len(human_detections)):
            det = human_detections[i]
            det[:, 0:4:2] *= w_ratio
            det[:, 1:4:2] *= h_ratio
            human_detections[i] = torch.from_numpy(det[:, :4]).to(args.device)
        timestamps, stdet_preds = rgb_based_stdet(args, frames,
                                                  stdet_label_map,
                                                  human_detections, w, h,
                                                  new_w, new_h, w_ratio,
                                                  h_ratio)

    stdet_results = []
    for timestamp, prediction in zip(timestamps, stdet_preds):
        human_detection = human_detections[timestamp - 1]
        stdet_results.append(
            pack_result(human_detection, prediction, new_h, new_w))

    def dense_timestamps(timestamps, n):
        """Make the sampled timestamps n times denser."""
        old_frame_interval = (timestamps[1] - timestamps[0])
        start = timestamps[0] - old_frame_interval / n * (n - 1) / 2
        new_frame_inds = np.arange(
            len(timestamps) * n) * old_frame_interval / n + start
        return new_frame_inds.astype(np.int64)

    dense_n = int(args.predict_stepsize / args.output_stepsize)
    output_timestamps = dense_timestamps(timestamps, dense_n)
    frames = [
        cv2.imread(frame_paths[timestamp - 1])
        for timestamp in output_timestamps
    ]

    if args.use_skeleton_recog or args.use_skeleton_stdet:
        pose_datasample = [
            pose_datasample[timestamp - 1] for timestamp in output_timestamps
        ]

    vis_frames = visualize(args, frames, stdet_results, pose_datasample,
                           action_result)
    vid = mpy.ImageSequenceClip(vis_frames, fps=args.output_fps)
    vid.write_videofile(args.out_filename)

    tmp_dir.cleanup()


if __name__ == '__main__':
    main()