File size: 3,049 Bytes
d3dbf03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright (c) OpenMMLab. All rights reserved.
import gc
from typing import Dict, List, Union

from mmengine.runner import EpochBasedTrainLoop
from torch.utils.data import DataLoader

from mmaction.registry import LOOPS


class EpochMultiLoader:
    """Iterate several dataloaders in lockstep for one epoch.

    Each ``next()`` call draws one item from every loader and returns them
    as a tuple, in the order the loaders were given. Iteration stops as soon
    as any loader is exhausted, so the number of steps equals the length of
    the shortest loader. The per-loader iterators are created once, on
    construction, so a fresh instance is needed for every pass.

    Args:
        dataloaders (List[DataLoader]): The loaders to iterate together.
            Any iterable of loaders is accepted. An empty collection yields
            nothing.
    """

    def __init__(self, dataloaders: List[DataLoader]):
        # Materialize into a list so that ``num_loaders`` and ``__len__``
        # keep working even when an arbitrary iterable (e.g. a generator)
        # is passed in; iterating it below would otherwise consume it.
        self._dataloaders = list(dataloaders)
        self.iter_loaders = [iter(loader) for loader in self._dataloaders]

    @property
    def num_loaders(self):
        """The number of dataloaders."""
        return len(self._dataloaders)

    def __iter__(self):
        """Return self; this object is its own iterator."""
        return self

    def __next__(self):
        """Return a tuple holding the next item from every loader.

        Raises:
            StopIteration: When any loader is exhausted, or when there are
                no loaders at all.
        """
        if not self.iter_loaders:
            # Without this guard an empty tuple would be produced forever.
            raise StopIteration
        # A list comprehension (not a generator expression) is required:
        # a StopIteration raised inside a generator expression would be
        # converted into a RuntimeError (PEP 479) instead of ending the
        # iteration cleanly.
        return tuple([next(loader) for loader in self.iter_loaders])

    def __len__(self):
        """Return the number of steps, i.e. the shortest loader's length.

        Every loader must support ``len()``. Returns 0 when there are no
        loaders.
        """
        return min((len(loader) for loader in self._dataloaders), default=0)


@LOOPS.register_module()
class MultiLoaderEpochBasedTrainLoop(EpochBasedTrainLoop):
    """EpochBasedTrainLoop with multiple dataloaders.

    Each training iteration draws one batch from the main ``dataloader`` and
    one batch from every loader in ``other_loaders``. The batches are passed
    to ``run_iter`` together as a single tuple, main loader first. An epoch
    ends as soon as any of the loaders is exhausted.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or Dict): A dataloader object or a dict to
            build a dataloader for training the model.
        other_loaders (List of Dataloader or Dict): A list of other loaders.
            Each item in the list is a dataloader object or a dict to build
            a dataloader.
        max_epochs (int): Total training epochs.
        val_begin (int): The epoch that begins validating. Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[Dict, DataLoader],
                 other_loaders: List[Union[Dict, DataLoader]],
                 max_epochs: int,
                 val_begin: int = 1,
                 val_interval: int = 1) -> None:
        super().__init__(runner, dataloader, max_epochs, val_begin,
                         val_interval)
        # The main loader comes first, so it is element 0 of every batch
        # tuple handed to ``run_iter``.
        multi_loaders = [self.dataloader]
        for loader in other_loaders:
            if isinstance(loader, dict):
                # Config dicts are built with the runner's seed.
                loader = runner.build_dataloader(loader, seed=runner.seed)
            multi_loaders.append(loader)

        self.multi_loaders = multi_loaders

    @staticmethod
    def _set_sampler_epoch(loader, epoch: int) -> None:
        """Call ``set_epoch`` on the sampler that drives ``loader``, if any.

        For a PyTorch ``DataLoader`` built with a ``batch_sampler``,
        ``loader.sampler`` is a default sampler that does not decide the
        sample order. The order comes from the sampler wrapped by the batch
        sampler. Both locations are checked so that shuffling is re-seeded
        every epoch in either configuration.

        Args:
            loader: The dataloader whose sampler should be updated.
            epoch (int): The epoch number to pass to ``set_epoch``.
        """
        sampler = getattr(loader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)
            return
        batch_sampler = getattr(loader, 'batch_sampler', None)
        inner_sampler = getattr(batch_sampler, 'sampler', None)
        if hasattr(inner_sampler, 'set_epoch'):
            inner_sampler.set_epoch(epoch)

    def run_epoch(self) -> None:
        """Iterate one epoch over all loaders in lockstep."""
        self.runner.call_hook('before_train_epoch')
        self.runner.model.train()

        # NOTE(review): an explicit GC pass once per epoch, presumably to free
        # objects left over from the previous epoch's loader iterators;
        # confirm the intent.
        gc.collect()
        # Epochs are set before EpochMultiLoader calls ``iter()`` on each
        # loader below, so the new epoch applies to this pass.
        for loader in self.multi_loaders:
            self._set_sampler_epoch(loader, self._epoch)

        for idx, data_batch in enumerate(EpochMultiLoader(self.multi_loaders)):
            self.run_iter(idx, data_batch)

        self.runner.call_hook('after_train_epoch')
        self._epoch += 1