# Copyright (c) OpenMMLab. All rights reserved.
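# Example invocation (file names are illustrative, not part of the repo):
#   python demo/demo_skeleton.py input_video.mp4 output_video.mp4 \
#       --device cuda:0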
import argparse
import tempfile

import cv2
import mmcv
import mmengine
import torch
from mmengine import DictAction
from mmengine.utils import track_iter_progress

from mmaction.apis import (detection_inference, inference_skeleton,
                           init_recognizer, pose_inference)
from mmaction.registry import VISUALIZERS
from mmaction.utils import frame_extract

try:
    import moviepy.editor as mpy
except ImportError:
    raise ImportError('Please install moviepy to enable output file')
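
# Rendering settings for the action-label text overlay (cv2.putText).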
FONTFACE = cv2.FONT_HERSHEY_DUPLEX
FONTSCALE = 0.75
FONTCOLOR = (255, 255, 255) # BGR, white
THICKNESS = 1
LINETYPE = 1


def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    parser.add_argument('video', help='video file/url')
    parser.add_argument('out_filename', help='output filename')
    parser.add_argument(
        '--config',
        default=('configs/skeleton/posec3d/'
                 'slowonly_r50_8xb16-u48-240e_ntu60-xsub-keypoint.py'),
        help='skeleton model config file path')
    parser.add_argument(
        '--checkpoint',
        default=('https://download.openmmlab.com/mmaction/skeleton/posec3d/'
                 'slowonly_r50_u48_240e_ntu60_xsub_keypoint/'
                 'slowonly_r50_u48_240e_ntu60_xsub_keypoint-f3adabf1.pth'),
        help='skeleton model checkpoint file/url')
    parser.add_argument(
        '--det-config',
        default='demo/demo_configs/faster-rcnn_r50_fpn_2x_coco_infer.py',
        help='human detection config file path (from mmdet)')
    parser.add_argument(
        '--det-checkpoint',
        default=('http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/'
                 'faster_rcnn_r50_fpn_2x_coco/'
                 'faster_rcnn_r50_fpn_2x_coco_'
                 'bbox_mAP-0.384_20200504_210434-a5d8aa15.pth'),
        help='human detection checkpoint file/url')
    parser.add_argument(
        '--det-score-thr',
        type=float,
        default=0.9,
        help='the threshold of human detection score')
    parser.add_argument(
        '--det-cat-id',
        type=int,
        default=0,
        help='the category id for human detection')
    parser.add_argument(
        '--pose-config',
        default='demo/demo_configs/'
        'td-hm_hrnet-w32_8xb64-210e_coco-256x192_infer.py',
        help='human pose estimation config file path (from mmpose)')
    parser.add_argument(
        '--pose-checkpoint',
        default=('https://download.openmmlab.com/mmpose/top_down/hrnet/'
                 'hrnet_w32_coco_256x192-c78dce93_20200708.pth'),
        help='human pose estimation checkpoint file/url')
    parser.add_argument(
        '--label-map',
        default='tools/data/skeleton/label_map_ntu60.txt',
        help='label map file')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--short-side',
        type=int,
        default=480,
        help='specify the short-side length of the image')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config; key-value pairs '
        'in xxx=yyy format will be merged into the config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    args = parser.parse_args()
    return args


def visualize(args, frames, data_samples, action_label):
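    """Draw pose results on every frame, stamp the predicted action label in
    the top-left corner, and encode the frames into the output video."""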
    pose_config = mmengine.Config.fromfile(args.pose_config)
    visualizer = VISUALIZERS.build(pose_config.visualizer)
    visualizer.set_dataset_meta(data_samples[0].dataset_meta)

    vis_frames = []
    print('Drawing skeleton for each frame')
    for d, f in track_iter_progress(list(zip(data_samples, frames))):
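        # OpenCV frames are BGR; the visualizer and moviepy expect RGB.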
        f = mmcv.imconvert(f, 'bgr', 'rgb')
        visualizer.add_datasample(
            'result',
            f,
            data_sample=d,
            draw_gt=False,
            draw_heatmap=False,
            draw_bbox=True,
            show=False,
            wait_time=0,
            out_file=None,
            kpt_thr=0.3)
        vis_frame = visualizer.get_image()
        cv2.putText(vis_frame, action_label, (10, 30), FONTFACE, FONTSCALE,
                    FONTCOLOR, THICKNESS, LINETYPE)
        vis_frames.append(vis_frame)
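
    # Encode the annotated RGB frames into the output file at a fixed 24 fps.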
    vid = mpy.ImageSequenceClip(vis_frames, fps=24)
    vid.write_videofile(args.out_filename, remove_temp=True)


def main():
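    """Run the demo pipeline: frame extraction, human detection, pose
    estimation, skeleton-based action recognition, then visualization."""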
    args = parse_args()

    tmp_dir = tempfile.TemporaryDirectory()
    frame_paths, frames = frame_extract(args.video, args.short_side,
                                        tmp_dir.name)
    h, w, _ = frames[0].shape
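    # Frame height/width are passed to inference_skeleton as the image shape.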

    # Get Human detection results.
    det_results, _ = detection_inference(args.det_config, args.det_checkpoint,
                                         frame_paths, args.det_score_thr,
                                         args.det_cat_id, args.device)
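    # Release cached GPU memory before the next model is loaded.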
    torch.cuda.empty_cache()

    # Get Pose estimation results.
    pose_results, pose_data_samples = pose_inference(args.pose_config,
                                                     args.pose_checkpoint,
                                                     frame_paths, det_results,
                                                     args.device)
    torch.cuda.empty_cache()
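
    # Build the skeleton-based recognizer and classify the pose sequence.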
    config = mmengine.Config.fromfile(args.config)
    config.merge_from_dict(args.cfg_options)
    model = init_recognizer(config, args.checkpoint, args.device)

    result = inference_skeleton(model, pose_results, (h, w))
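
    # Map the highest-scoring class index to its human-readable name.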
    max_pred_index = result.pred_score.argmax().item()
    with open(args.label_map) as f:
        label_map = [x.strip() for x in f.readlines()]
    action_label = label_map[max_pred_index]

    visualize(args, frames, pose_data_samples, action_label)

    tmp_dir.cleanup()


if __name__ == '__main__':
    main()