|
|
|
import os.path as osp
|
|
|
|
import mmengine
|
|
import numpy as np
|
|
import torch
|
|
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
|
|
|
|
|
def check_norm_state(modules, train_state):
    """Tell whether every batch-norm layer is in the expected mode.

    Args:
        modules (Iterable): Modules to inspect. Only instances of
            ``_BatchNorm`` are examined; everything else is skipped.
        train_state (bool): Value each norm layer's ``training`` flag is
            expected to have.

    Returns:
        bool: ``False`` if any norm layer's ``training`` flag differs from
        ``train_state``; ``True`` otherwise, including when ``modules``
        contains no norm layer at all.
    """
    norm_layers = (layer for layer in modules
                   if isinstance(layer, _BatchNorm))
    # ``any`` stops at the first mismatching layer, like an early return.
    return not any(layer.training != train_state for layer in norm_layers)
|
|
|
|
|
|
def generate_backbone_demo_inputs(input_shape=(1, 3, 64, 64)):
    """Build a random input batch for exercising a backbone.

    Args:
        input_shape (tuple): Shape of the generated batch.
            Defaults to ``(1, 3, 64, 64)``.

    Returns:
        torch.Tensor: Float tensor of shape ``input_shape`` holding samples
        drawn with ``np.random.random``.
    """
    random_values = np.random.random(input_shape)
    return torch.FloatTensor(random_values)
|
|
|
|
|
|
|
|
def generate_recognizer_demo_inputs(
        input_shape=(1, 3, 3, 224, 224), model_type='2D'):
    """Create a superset of inputs needed to run test or train batches.

    The number of generated labels depends on ``model_type``:

    - ``'2D'`` / ``'skeleton'``: ``input_shape[0]`` labels.
    - ``'3D'``: ``input_shape[1]`` labels (6-dimensional shapes only).
    - ``'audio'``: ``input_shape[1]`` labels for a 5-dimensional shape and
      ``input_shape[3]`` labels for a 6-dimensional shape.

    Args:
        input_shape (tuple): Input batch dimensions, 5 or 6 entries long.
            Defaults to ``(1, 3, 3, 224, 224)``.
        model_type (str): Model type for data generation, from
            {'2D', '3D', 'skeleton', 'audio'}. Defaults to ``'2D'``.

    Returns:
        dict: ``imgs`` is a random float tensor of shape ``input_shape`` and
        ``gt_labels`` is a long tensor in which every entry is ``2``.

    Raises:
        ValueError: If ``model_type`` is unknown, if ``input_shape`` does not
            have 5 or 6 dimensions, or if ``model_type`` is ``'3D'`` while
            ``input_shape`` has only 5 dimensions.
    """
    if model_type not in ('2D', 'skeleton', '3D', 'audio'):
        raise ValueError(f'Data type {model_type} is not available')

    ndim = len(input_shape)
    if ndim == 5:
        # A 5-dimensional shape carries no dimension for the '3D' label
        # count, so '3D' is deliberately absent from this table.
        label_counts = {
            '2D': input_shape[0],
            'skeleton': input_shape[0],
            'audio': input_shape[1],
        }
    elif ndim == 6:
        label_counts = {
            '2D': input_shape[0],
            'skeleton': input_shape[0],
            '3D': input_shape[1],
            'audio': input_shape[3],
        }
    else:
        raise ValueError(
            f'input_shape must have 5 or 6 dimensions, but got {ndim}')

    if model_type not in label_counts:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(
            f'Data type {model_type} requires a 6-dimensional input_shape, '
            f'but got {ndim} dimensions')

    imgs = np.random.random(input_shape)
    gt_labels = torch.LongTensor([2] * label_counts[model_type])
    return {'imgs': torch.FloatTensor(imgs), 'gt_labels': gt_labels}
|
|
|
|
|
|
def generate_detector_demo_inputs(
        input_shape=(1, 3, 4, 224, 224), num_classes=81, train=True,
        device='cpu'):
    """Create random inputs for running a detector in train or test mode.

    Two random boxes and two random multi-hot label rows are generated for
    every sample in the batch.

    Args:
        input_shape (tuple): Shape of the generated ``img`` tensor. Entry 0
            is the number of samples, entries 3 and 4 scale the random
            boxes, and the last two entries are reported as ``img_shape``.
            Defaults to ``(1, 3, 4, 224, 224)``.
        num_classes (int): Number of columns in each label tensor.
            Defaults to 81.
        train (bool): Whether to build the training-style dict. When
            ``False`` the batch must contain exactly one sample.
            Defaults to True.
        device (str | torch.device): Device the generated tensors are moved
            to. Defaults to ``'cpu'``.

    Returns:
        dict: With ``train=True`` the keys are ``img``, ``proposals``,
        ``gt_bboxes``, ``gt_labels`` and ``img_metas``. With ``train=False``
        only ``img``, ``proposals`` and ``img_metas`` are returned, each
        wrapped in a one-element list.
    """
    num_samples = input_shape[0]
    if not train:
        assert num_samples == 1

    def random_box(n):
        # Columns 0-1 fall in [0, 0.5) and columns 2-3 in [0.5, 1) before
        # scaling, so the last two coordinates always exceed the first two.
        box = torch.rand(n, 4) * 0.5
        box[:, 2:] += 0.5
        # NOTE(review): columns 0/2 are scaled by input_shape[3] and columns
        # 1/3 by input_shape[4]. If boxes are (x1, y1, x2, y2) and the shape
        # ends in (H, W), these look swapped for non-square inputs - confirm.
        box[:, 0::2] *= input_shape[3]
        box[:, 1::2] *= input_shape[4]
        # Values are sampled on CPU and then moved, so a seeded run gives
        # the same numbers regardless of the target device.
        return box.to(device)

    def random_label(n):
        label = torch.randn(n, num_classes)
        label = (label > 0.8).type(torch.float32)
        # Column 0 is never marked positive.
        label[:, 0] = 0
        return label.to(device)

    img = torch.FloatTensor(np.random.random(input_shape)).to(device)

    # Generation order is kept fixed so seeded runs stay reproducible.
    proposals = [random_box(2) for _ in range(num_samples)]
    gt_bboxes = [random_box(2) for _ in range(num_samples)]
    gt_labels = [random_label(2) for _ in range(num_samples)]
    img_metas = [dict(img_shape=input_shape[-2:]) for _ in range(num_samples)]

    if train:
        return dict(
            img=img,
            proposals=proposals,
            gt_bboxes=gt_bboxes,
            gt_labels=gt_labels,
            img_metas=img_metas)

    return dict(img=[img], proposals=[proposals], img_metas=[img_metas])
|
|
|
|
|
|
def get_cfg(config_type, fname):
    """Load a config file from the repository's ``configs`` directory.

    Each call reads the file again through ``mmengine.Config.fromfile``; no
    additional copy of the result is made here.

    Args:
        config_type (str): Sub-directory of ``configs`` to look in. One of
            'recognition', 'recognition_audio', 'localization', 'detection',
            'skeleton' or 'retrieval'.
        fname (str): Path of the config file relative to
            ``configs/<config_type>``.

    Returns:
        The config object returned by ``mmengine.Config.fromfile``.

    Raises:
        AssertionError: If ``config_type`` is not one of the supported types.
        FileNotFoundError: If the resolved config file does not exist.
    """
    config_types = ('recognition', 'recognition_audio', 'localization',
                    'detection', 'skeleton', 'retrieval')
    assert config_type in config_types, (
        f'config_type must be one of {config_types}, but got {config_type!r}')

    # The repository root is taken to be three directory levels above this
    # file.
    repo_dpath = osp.dirname(osp.dirname(osp.dirname(__file__)))
    config_fpath = osp.join(repo_dpath, 'configs', config_type, fname)
    # Check the file itself (not only its directory) so a wrong ``fname`` is
    # reported here together with the full path that was tried.
    if not osp.exists(config_fpath):
        raise FileNotFoundError(f'Cannot find config path: {config_fpath}')
    return mmengine.Config.fromfile(config_fpath)
|
|
|
|
|
|
def get_recognizer_cfg(fname):
    """Load ``fname`` through ``get_cfg`` with the 'recognition' type."""
    return get_cfg(config_type='recognition', fname=fname)
|
|
|
|
|
|
def get_audio_recognizer_cfg(fname):
    """Load ``fname`` through ``get_cfg`` with the 'recognition_audio' type."""
    return get_cfg(config_type='recognition_audio', fname=fname)
|
|
|
|
|
|
def get_localizer_cfg(fname):
    """Load ``fname`` through ``get_cfg`` with the 'localization' type."""
    return get_cfg(config_type='localization', fname=fname)
|
|
|
|
|
|
def get_detector_cfg(fname):
    """Load ``fname`` through ``get_cfg`` with the 'detection' type."""
    return get_cfg(config_type='detection', fname=fname)
|
|
|
|
|
|
def get_skeletongcn_cfg(fname):
    """Load ``fname`` through ``get_cfg`` with the 'skeleton' type."""
    return get_cfg(config_type='skeleton', fname=fname)
|
|
|
|
|
|
def get_similarity_cfg(fname):
    """Load ``fname`` through ``get_cfg`` with the 'retrieval' type."""
    return get_cfg(config_type='retrieval', fname=fname)
|
|
|