from unittest import TestCase

import numpy as np
import pytest
import torch

from mmaction.evaluation.metrics import RetrievalMetric, RetrievalRecall
from mmaction.registry import METRICS
from mmaction.structures import ActionDataSample


def generate_data(num_samples=5, feat_dim=10, random_label=False):
    """Build a ``(data_batch, data_samples)`` pair of video/text features.

    When ``random_label`` is False, each text feature is a clone of its
    video feature, so every sample is a perfect retrieval match.
    ``data_batch`` is returned empty; the tests below pass it straight
    through to ``process``.
    """
    data_batch = []
    data_samples = []
    for _ in range(num_samples):
        if random_label:
            video_feature = torch.randn(feat_dim)
            text_feature = torch.randn(feat_dim)
        else:
            video_feature = torch.randn(feat_dim)
            text_feature = video_feature.clone()

        data_sample = dict(
            features=dict(
                video_feature=video_feature, text_feature=text_feature))
        data_samples.append(data_sample)
    return data_batch, data_samples
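
# Usage sketch for the helper above (mirrors test_acc_metric below); the
# samples only carry a `features` dict, and the empty `data_batch` is passed
# through to `process` as-is:
#
#   metric = RetrievalMetric()
#   data_batch, data_samples = generate_data(num_samples=4)
#   metric.process(data_batch, data_samples)
#   eval_results = metric.compute_metrics(metric.results)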


def test_acc_metric():
    with pytest.raises(ValueError):
        RetrievalMetric(metric_list='R100')

    num_samples = 20
    metric = RetrievalMetric()
    data_batch, predictions = generate_data(
        num_samples=num_samples, random_label=True)
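    # `process` stores each sample's feature pair in `metric.results`;
    # `compute_metrics` then ranks every query against all candidates and
    # reduces the ranks to recall (R1/R5/R10) and rank statistics
    # (MdR = median rank, MnR = mean rank).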
    metric.process(data_batch, predictions)
    eval_results = metric.compute_metrics(metric.results)
    assert 0.0 <= eval_results['R1'] <= eval_results['R5'] <= eval_results[
        'R10'] <= 100.0
    assert 0.0 <= eval_results['MdR'] <= num_samples
    assert 0.0 <= eval_results['MnR'] <= num_samples

    metric.results.clear()
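
    # With random_label=False every text feature equals its video feature, so
    # (assuming the metric ranks by feature similarity) each query's ground
    # truth sorts first: recall saturates at 100 and both the median and the
    # mean rank collapse to 1.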
    data_batch, predictions = generate_data(
        num_samples=num_samples, random_label=False)
    metric.process(data_batch, predictions)
    eval_results = metric.compute_metrics(metric.results)
    assert eval_results['R1'] == eval_results['R5'] == eval_results[
        'R10'] == 100.0
    assert eval_results['MdR'] == eval_results['MnR'] == 1.0


class TestRetrievalRecall(TestCase):

    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""
        pred = [
            ActionDataSample().set_pred_score(i).set_gt_label(k).to_dict()
            for i, k in zip([
                torch.tensor([0.7, 0.0, 0.3]),
                torch.tensor([0.5, 0.2, 0.3]),
                torch.tensor([0.4, 0.5, 0.1]),
                torch.tensor([0.0, 0.0, 1.0]),
                torch.tensor([0.0, 0.0, 1.0]),
                torch.tensor([0.0, 0.0, 1.0]),
            ], [[0], [0], [1], [2], [2], [0]])
        ]
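
        # Five of the six samples rank their ground-truth label first; only
        # the last one (argmax 2 vs. gt label 0) misses, hence
        # Recall@1 = 5 / 6.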
        metric = METRICS.build(dict(type='RetrievalRecall', topk=1))
        metric.process(None, pred)
        recall = metric.evaluate(6)
        self.assertIsInstance(recall, dict)
        self.assertAlmostEqual(
            recall['retrieval/Recall@1'], 5 / 6 * 100, places=4)
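
        # A k larger than the number of classes propagates the underlying
        # topk error ('selected index k'), while a non-positive k is rejected
        # at construction time.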
        with self.assertRaisesRegex(RuntimeError, 'selected index k'):
            metric = METRICS.build(dict(type='RetrievalRecall', topk=10))
            metric.process(None, pred)
            metric.evaluate(6)

        with self.assertRaisesRegex(ValueError, '`topk` must be a'):
            METRICS.build(dict(type='RetrievalRecall', topk=-1))

        metric = METRICS.build(dict(type='RetrievalRecall', topk=5))
        self.assertEqual(metric.topk, (5, ))

        metric = METRICS.build(dict(type='RetrievalRecall', topk=(1, 2, 5)))
        self.assertEqual(metric.topk, (1, 2, 5))

    def test_calculate(self):
        """Test using the metric from static method."""

        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10)] * 2
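
        # With *_indices=True each y_pred row is already a ranking, so the
        # top-1 prediction is index 0: a hit for the first sample (0 is in
        # its target set) and a miss for the second ([1, 4, 6]), hence 50%.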
        recall_score = RetrievalRecall.calculate(
            y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
        expect_recall = 50.
        self.assertEqual(recall_score[0].item(), expect_recall)

        y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
                               [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
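        # In score mode the one-hot rows of y_true mark the relevant classes.
        # The scores decay from index 0, so top-1 is always class 0: a hit
        # for row one and a miss for row two (50%); widening to top-2 also
        # captures row two's class 1 (100%).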
        recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=1)
        expect_recall = 50.
        self.assertEqual(recall_score[0].item(), expect_recall)

        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=2)
        expect_recall = 100.
        self.assertEqual(recall_score[0].item(), expect_recall)

        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=(1, 5))
        expect_recalls = [50., 100.]
        self.assertEqual(len(recall_score), len(expect_recalls))
        for i in range(len(expect_recalls)):
            self.assertEqual(recall_score[i].item(), expect_recalls[i])
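
        # A tuple `topk` yields one recall per k, in order. The remaining
        # cases pin down input validation: with *_indices=True, both
        # arguments must be sequences of matching length whose elements are
        # themselves index sequences.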
        y_pred = dict()
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        with self.assertRaisesRegex(AssertionError, '`pred` must be Seq'):
            RetrievalRecall.calculate(y_pred, y_true, True, True)

        y_true = dict()
        y_pred = [np.arange(10)] * 2
        with self.assertRaisesRegex(AssertionError, '`target` must be Seq'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)

        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10)] * 3
        with self.assertRaisesRegex(AssertionError, 'Length of `pred`'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)

        y_true = [[0, 2, 5, 8, 9], dict()]
        y_pred = [np.arange(10)] * 2
        with self.assertRaisesRegex(AssertionError, '`target` should be'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)

        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10), dict()]
        with self.assertRaisesRegex(AssertionError, '`pred` should be'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)