# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from operator import itemgetter

import torch
from mmengine import Config, DictAction

from mmaction.apis import inference_recognizer, init_recognizer


def parse_args():
    """Parse the command-line arguments of the audio demo.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
        ``checkpoint``, ``audio``, ``label``, ``cfg_options`` and ``device``.
    """
    parser = argparse.ArgumentParser(description='MMAction2 demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file/url')
    parser.add_argument('audio', help='audio file')
    parser.add_argument('label', help='label file')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    args = parser.parse_args()
    return args


def main():
    """Run a recognizer on one audio feature file and print the top-5 labels.

    The label file is read line by line; the label for a prediction index is
    the (stripped) line at that position in the file.

    Raises:
        NotImplementedError: If the ``audio`` argument does not end with
            ``.npy``.
    """
    args = parse_args()

    # Reject unsupported input before building the model, so a wrong path
    # does not first pay for config parsing and checkpoint loading.
    if not args.audio.endswith('.npy'):
        raise NotImplementedError('Demo works on extracted audio features')

    device = torch.device(args.device)
    cfg = Config.fromfile(args.config)
    cfg.merge_from_dict(args.cfg_options)
    model = init_recognizer(cfg, args.checkpoint, device=device)

    pred_result = inference_recognizer(model, args.audio)
    # Presumably one score per class, indexed by class id -- confirm against
    # the recognizer's output format.
    pred_scores = pred_result.pred_score.tolist()

    # Pair each score with its index and keep the five highest-scoring pairs.
    score_sorted = sorted(
        enumerate(pred_scores), key=itemgetter(1), reverse=True)
    top5_label = score_sorted[:5]

    # Use a context manager so the label file is closed deterministically.
    with open(args.label) as label_file:
        labels = [line.strip() for line in label_file]
    results = [(labels[index], score) for index, score in top5_label]

    print('The top-5 labels with corresponding scores are:')
    for label, score in results:
        print(f'{label}: ', score)


if __name__ == '__main__':
    main()