File size: 1,565 Bytes
d3dbf03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
from mmaction.datasets import PoseDataset
from .base import BaseTestDataset
class TestPoseDataset(BaseTestDataset):
    """Tests for ``PoseDataset`` driven by ``self.pose_ann_file``."""

    def test_pose_dataset(self):
        """Build ``PoseDataset`` under several settings and check the result.

        Scenarios covered, in order:

        1. ``box_thr=0.5`` with a ``video`` data prefix: expects 100 samples
           and that the first sample's ``frame_dir`` starts with the prefix.
        2. ``valid_ratio``/``box_thr`` pairs ``(0.2, 0.9)`` and
           ``(0.3, 0.7)``: expects 84 and 87 samples respectively, with
           every sample satisfying both the score and the ratio bound.
        3. ``box_thr=0.55`` without ``split``: construction must raise
           ``AssertionError``.
        """
        pose_anns = self.pose_ann_file
        prefix = dict(video='root')

        # Scenario 1: prefix handling and the unfiltered sample count.
        pose_set = PoseDataset(
            ann_file=pose_anns,
            pipeline=[],
            split='train',
            box_thr=0.5,
            data_prefix=prefix)
        assert len(pose_set) == 100
        first = pose_set[0]
        assert first['frame_dir'].startswith(prefix['video'])

        # Scenario 2: each row is (valid_ratio, box_thr, expected length).
        filter_cases = ((0.2, 0.9, 84), (0.3, 0.7, 87))
        for ratio, thr, expected_len in filter_cases:
            pose_set = PoseDataset(
                ann_file=pose_anns,
                pipeline=[],
                split='train',
                valid_ratio=ratio,
                box_thr=thr)
            assert len(pose_set) == expected_len
            for sample in pose_set:
                # Scores selected by ``anno_inds`` must reach the threshold.
                kept_scores = sample['box_score'][sample['anno_inds']]
                assert np.all(kept_scores >= thr)
                # ``valid`` is looked up by the threshold value itself.
                assert sample['valid'][thr] / sample['total_frames'] >= ratio

        # Scenario 3: NOTE(review): presumably rejected because 0.55 is not
        # a supported ``box_thr`` value -- confirm against PoseDataset.
        with pytest.raises(AssertionError):
            PoseDataset(
                ann_file=pose_anns, pipeline=[], valid_ratio=0.2, box_thr=0.55)
|