# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import pytest
import torch

from mmaction.evaluation.metrics import RetrievalMetric, RetrievalRecall
from mmaction.registry import METRICS
from mmaction.structures import ActionDataSample


def generate_data(num_samples=5, feat_dim=10, random_label=False):
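    """Generate paired video/text features for ``RetrievalMetric`` tests.

    When ``random_label`` is False the text feature is a clone of the video
    feature, so retrieval is trivially perfect; otherwise the two features
    are drawn independently. ``data_batch`` stays empty because only
    ``data_samples`` carries the features consumed by the metric.
    """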
    data_batch = []
    data_samples = []
    for i in range(num_samples):
        if random_label:
            video_feature = torch.randn(feat_dim)
            text_feature = torch.randn(feat_dim)
        else:
            video_feature = torch.randn(feat_dim)
            text_feature = video_feature.clone()

        data_sample = dict(
            features=dict(
                video_feature=video_feature, text_feature=text_feature))
        data_samples.append(data_sample)
    return data_batch, data_samples


def test_acc_metric():
    with pytest.raises(ValueError):
        RetrievalMetric(metric_list='R100')

    num_samples = 20
    metric = RetrievalMetric()
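    # Random, independent video/text features: results can only be checked
    # for valid ranges and for monotonicity of recall in k.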
    data_batch, predictions = generate_data(
        num_samples=num_samples, random_label=True)
    metric.process(data_batch, predictions)
    eval_results = metric.compute_metrics(metric.results)
    assert 0.0 <= eval_results['R1'] <= eval_results['R5'] <= eval_results[
        'R10'] <= 100.0
    assert 0.0 <= eval_results['MdR'] <= num_samples
    assert 0.0 <= eval_results['MnR'] <= num_samples

    metric.results.clear()

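    # Identical video/text features: every query retrieves its paired sample
    # at rank 1, so all recalls are 100 and both median and mean rank are 1.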
    data_batch, predictions = generate_data(
        num_samples=num_samples, random_label=False)
    metric.process(data_batch, predictions)
    eval_results = metric.compute_metrics(metric.results)
    assert eval_results['R1'] == eval_results['R5'] == eval_results[
        'R10'] == 100.0
    assert eval_results['MdR'] == eval_results['MnR'] == 1.0
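
    # A minimal sketch (assumption: ``metric_list`` also accepts a tuple
    # restricting which metrics are reported, mirroring the ValueError
    # raised above for the unsupported 'R100').
    metric = RetrievalMetric(metric_list=('R1', 'MdR'))
    metric.process(data_batch, predictions)
    eval_results = metric.compute_metrics(metric.results)
    assert 'R1' in eval_results and 'MdR' in eval_results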


class TestRetrievalRecall(TestCase):

    def test_evaluate(self):
        """Test using the metric in the same way as Evalutor."""
        pred = [
            ActionDataSample().set_pred_score(i).set_gt_label(k).to_dict()
            for i, k in zip([
                torch.tensor([0.7, 0.0, 0.3]),
                torch.tensor([0.5, 0.2, 0.3]),
                torch.tensor([0.4, 0.5, 0.1]),
                torch.tensor([0.0, 0.0, 1.0]),
                torch.tensor([0.0, 0.0, 1.0]),
                torch.tensor([0.0, 0.0, 1.0]),
            ], [[0], [0], [1], [2], [2], [0]])
        ]

        # Test with score (use score instead of label if score exists)
        metric = METRICS.build(dict(type='RetrievalRecall', topk=1))
        metric.process(None, pred)
        recall = metric.evaluate(6)
        self.assertIsInstance(recall, dict)
        self.assertAlmostEqual(
            recall['retrieval/Recall@1'], 5 / 6 * 100, places=4)

        # Test with invalid topk
        with self.assertRaisesRegex(RuntimeError, 'selected index k'):
            metric = METRICS.build(dict(type='RetrievalRecall', topk=10))
            metric.process(None, pred)
            metric.evaluate(6)

        with self.assertRaisesRegex(ValueError, '`topk` must be a'):
            METRICS.build(dict(type='RetrievalRecall', topk=-1))

        # Test initialization
        metric = METRICS.build(dict(type='RetrievalRecall', topk=5))
        self.assertEqual(metric.topk, (5, ))

        # Test initialization
        metric = METRICS.build(dict(type='RetrievalRecall', topk=(1, 2, 5)))
        self.assertEqual(metric.topk, (1, 2, 5))
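
        # A minimal sketch (assumption: the metric follows the standard
        # mmengine ``BaseMetric`` contract, so ``evaluate()`` clears the
        # accumulated results and the metric instance can be reused).
        metric = METRICS.build(dict(type='RetrievalRecall', topk=1))
        metric.process(None, pred)
        metric.evaluate(len(pred))
        self.assertEqual(len(metric.results), 0)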

    def test_calculate(self):
        """Test using the metric from static method."""

        # seq of indices format
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10)] * 2

        # test with topk=1 on index-format pred and target
        recall_score = RetrievalRecall.calculate(
            y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
        expect_recall = 50.
        self.assertEqual(recall_score[0].item(), expect_recall)

        # test with tensor input
        y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
                               [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=1)
        expect_recall = 50.
        self.assertEqual(recall_score[0].item(), expect_recall)

        # test with topk=2
        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=2)
        expect_recall = 100.
        self.assertEqual(recall_score[0].item(), expect_recall)

        # test with topk=(1, 5)
        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=(1, 5))
        expect_recalls = [50., 100.]
        self.assertEqual(len(recall_score), len(expect_recalls))
        for i in range(len(expect_recalls)):
            self.assertEqual(recall_score[i].item(), expect_recalls[i])

        # Test with invalid pred
        y_pred = dict()
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        with self.assertRaisesRegex(AssertionError, '`pred` must be Seq'):
            RetrievalRecall.calculate(y_pred, y_true, True, True)

        # Test with invalid target
        y_true = dict()
        y_pred = [np.arange(10)] * 2
        with self.assertRaisesRegex(AssertionError, '`target` must be Seq'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)

        # Test with `pred` and `target` of different lengths
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10)] * 3
        with self.assertRaisesRegex(AssertionError, 'Length of `pred`'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)

        # Test with an invalid element inside `target`
        y_true = [[0, 2, 5, 8, 9], dict()]
        y_pred = [np.arange(10)] * 2
        with self.assertRaisesRegex(AssertionError, '`target` should be'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)

        # Test with an invalid element inside `pred`
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10), dict()]
        with self.assertRaisesRegex(AssertionError, '`pred` should be'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)