File size: 4,414 Bytes
d3dbf03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmengine
import numpy as np
import torch
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
def check_norm_state(modules, train_state):
    """Verify that every BatchNorm layer in ``modules`` is in the given mode.

    Args:
        modules (Iterable): Modules to inspect. Only instances of
            ``_BatchNorm`` are checked; every other module is ignored.
        train_state (bool): Expected value of each norm layer's
            ``training`` flag.

    Returns:
        bool: ``False`` as soon as a ``_BatchNorm`` whose ``training`` flag
        differs from ``train_state`` is found, otherwise ``True``.
    """
    return all(layer.training == train_state
               for layer in modules
               if isinstance(layer, _BatchNorm))
def generate_backbone_demo_inputs(input_shape=(1, 3, 64, 64)):
    """Build a random float tensor to feed a backbone in tests.

    Args:
        input_shape (tuple): Shape of the generated batch.
            Defaults to ``(1, 3, 64, 64)``.

    Returns:
        torch.Tensor: A ``float32`` tensor of shape ``input_shape`` whose
        values are sampled by ``np.random.random`` and cast down from
        float64.
    """
    # Values come from NumPy's global RNG, so ``np.random.seed`` controls them.
    return torch.from_numpy(np.random.random(input_shape)).float()
# TODO Remove this API
def generate_recognizer_demo_inputs(
        input_shape=(1, 3, 3, 224, 224), model_type='2D'):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions. Either 5-D, unpacked
            as ``(N, L, _, _, _)``, or 6-D, unpacked as
            ``(N, M, _, L, _, _)``. Defaults to ``(1, 3, 3, 224, 224)``.
        model_type (str): Model type for data generation, one of
            ``'2D'``, ``'skeleton'``, ``'3D'`` or ``'audio'``.
            Defaults to ``'2D'``.

    Returns:
        dict: ``imgs`` is a float32 tensor of shape ``input_shape`` filled
        by ``np.random.random``. ``gt_labels`` is an int64 tensor whose
        entries are all ``2``. Its length is ``N`` for ``'2D'`` and
        ``'skeleton'``, ``M`` for ``'3D'`` and ``L`` for ``'audio'``.

    Raises:
        ValueError: If ``input_shape`` is neither 5-D nor 6-D, if
            ``model_type`` is unknown, or if ``model_type='3D'`` is used
            with a 5-D shape (which has no ``M`` dimension).
    """
    if len(input_shape) == 5:
        N, L, _, _, _ = input_shape
        M = None  # 5-D shapes carry no M dimension.
    elif len(input_shape) == 6:
        N, M, _, L, _, _ = input_shape
    else:
        # Previously this fell through and failed later with an unbound
        # local variable; fail early with an explicit message instead.
        raise ValueError(
            f'input_shape must be 5-D or 6-D, got {tuple(input_shape)}')

    if model_type in ('2D', 'skeleton'):
        num_labels = N
    elif model_type == '3D':
        if M is None:
            raise ValueError(
                "model_type='3D' requires a 6-D input_shape, "
                f'got {tuple(input_shape)}')
        num_labels = M
    elif model_type == 'audio':
        num_labels = L
    else:
        raise ValueError(f'Data type {model_type} is not available')

    imgs = torch.FloatTensor(np.random.random(input_shape))
    gt_labels = torch.LongTensor([2] * num_labels)
    return {'imgs': imgs, 'gt_labels': gt_labels}
def generate_detector_demo_inputs(
        input_shape=(1, 3, 4, 224, 224), num_classes=81, train=True,
        device='cpu'):
    """Create random inputs for running a detector in tests.

    Args:
        input_shape (tuple): Shape of the image batch. ``input_shape[0]`` is
            the number of samples, and the last two entries are used as the
            image height and width (they are also recorded as ``img_shape``
            in ``img_metas``). Defaults to ``(1, 3, 4, 224, 224)``.
        num_classes (int): Width of each multi-hot label vector.
            Defaults to 81.
        train (bool): If ``True``, return training-style inputs including
            ground truth. Otherwise return test-style inputs, each wrapped in
            an extra one-element list. Defaults to ``True``.
        device (str | torch.device): Device on which every returned tensor
            is placed. Defaults to ``'cpu'``.

    Returns:
        dict: In train mode, keys ``img``, ``proposals``, ``gt_bboxes``,
        ``gt_labels`` and ``img_metas``. In test mode, keys ``img``,
        ``proposals`` and ``img_metas``. Every sample gets 2 proposals,
        2 gt boxes of shape ``(2, 4)`` and a ``(2, num_classes)`` float
        label matrix.
    """
    num_samples = input_shape[0]
    if not train:
        assert num_samples == 1, 'test-mode inputs support a single sample'
    height, width = input_shape[-2], input_shape[-1]

    def random_box(n):
        # Boxes are (x1, y1, x2, y2): the first corner is drawn from the
        # lower half of [0, 1) and the second from the upper half, so
        # x2 >= x1 and y2 >= y1. x is scaled by the width and y by the
        # height so boxes stay within ``img_shape``.
        box = torch.rand(n, 4) * 0.5
        box[:, 2:] += 0.5
        box[:, 0::2] *= width
        box[:, 1::2] *= height
        return box.to(device)

    def random_label(n):
        # Random multi-hot vectors; column 0 is always cleared.
        label = (torch.randn(n, num_classes) > 0.8).type(torch.float32)
        label[:, 0] = 0
        return label.to(device)

    img = torch.FloatTensor(np.random.random(input_shape)).to(device)
    proposals = [random_box(2) for _ in range(num_samples)]
    gt_bboxes = [random_box(2) for _ in range(num_samples)]
    gt_labels = [random_label(2) for _ in range(num_samples)]
    img_metas = [dict(img_shape=input_shape[-2:]) for _ in range(num_samples)]

    if train:
        return dict(
            img=img,
            proposals=proposals,
            gt_bboxes=gt_bboxes,
            gt_labels=gt_labels,
            img_metas=img_metas)
    # Test mode nests every entry one list level deeper.
    return dict(img=[img], proposals=[proposals], img_metas=[img_metas])
def get_cfg(config_type, fname):
    """Load a config file from the repository's ``configs`` directory.

    A new ``mmengine.Config`` is parsed from disk on every call, so callers
    can modify the returned object without affecting other callers.

    Args:
        config_type (str): Sub-directory of ``configs``; one of
            ``'recognition'``, ``'recognition_audio'``, ``'localization'``,
            ``'detection'``, ``'skeleton'`` or ``'retrieval'``.
        fname (str): Config path relative to ``configs/<config_type>``.

    Returns:
        mmengine.Config: The parsed config.

    Raises:
        AssertionError: If ``config_type`` is not a supported type.
        FileNotFoundError: If the config directory or the config file does
            not exist.
    """
    config_types = ('recognition', 'recognition_audio', 'localization',
                    'detection', 'skeleton', 'retrieval')
    assert config_type in config_types, (
        f'config_type must be one of {config_types}, got {config_type!r}')

    # NOTE(review): assumes this module sits three directory levels below
    # the repository root — confirm if the file is moved.
    repo_dpath = osp.dirname(osp.dirname(osp.dirname(__file__)))
    config_dpath = osp.join(repo_dpath, 'configs', config_type)
    config_fpath = osp.join(config_dpath, fname)
    if not osp.isdir(config_dpath):
        raise FileNotFoundError(f'Cannot find config path: {config_dpath}')
    if not osp.isfile(config_fpath):
        raise FileNotFoundError(f'Cannot find config file: {config_fpath}')
    return mmengine.Config.fromfile(config_fpath)
def get_recognizer_cfg(fname):
    """Load ``fname`` from ``configs/recognition`` via ``get_cfg``."""
    return get_cfg(config_type='recognition', fname=fname)
def get_audio_recognizer_cfg(fname):
    """Load ``fname`` from ``configs/recognition_audio`` via ``get_cfg``."""
    return get_cfg(config_type='recognition_audio', fname=fname)
def get_localizer_cfg(fname):
    """Load ``fname`` from ``configs/localization`` via ``get_cfg``."""
    return get_cfg(config_type='localization', fname=fname)
def get_detector_cfg(fname):
    """Load ``fname`` from ``configs/detection`` via ``get_cfg``."""
    return get_cfg(config_type='detection', fname=fname)
def get_skeletongcn_cfg(fname):
    """Load ``fname`` from ``configs/skeleton`` via ``get_cfg``."""
    return get_cfg(config_type='skeleton', fname=fname)
def get_similarity_cfg(fname):
    """Load ``fname`` from ``configs/retrieval`` via ``get_cfg``."""
    return get_cfg(config_type='retrieval', fname=fname)
|