File size: 2,719 Bytes
d3dbf03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Copyright (c) OpenMMLab. All rights reserved.
from functools import reduce
from operator import mul
from typing import List

import torch.nn as nn
from mmengine.logging import print_log
from mmengine.optim import DefaultOptimWrapperConstructor

from mmaction.registry import OPTIM_WRAPPER_CONSTRUCTORS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class SwinOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """Optimizer wrapper constructor with substring-keyed multipliers.

    Each key of ``paramwise_cfg`` is tested as a substring of a parameter's
    dotted name. Every key that matches contributes its ``lr_mult`` and
    ``decay_mult`` (both default to ``1.``), and the contributions of all
    matching keys are multiplied together.
    """

    def add_params(self,
                   params: List[dict],
                   module: nn.Module,
                   prefix: str = 'base',
                   **kwargs) -> None:
        """Collect one param group per parameter of ``module``, recursively.

        Parameters that do not require grad are appended as a bare
        ``{'params': [param]}`` group and are not logged. Trainable
        parameters start from ``self.base_lr`` (and ``self.base_wd`` when it
        is not ``None``), are scaled by the matching ``paramwise_cfg``
        entries, and have each resulting option logged via ``print_log``.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module whose parameters are added.
            prefix (str): The prefix of the module. Defaults to ``'base'``.
            **kwargs: Accepted but unused here; not forwarded to the
                recursive calls on child modules.
        """
        for param_name, tensor in module.named_parameters(recurse=False):
            group = {'params': [tensor]}
            if not tensor.requires_grad:
                # Frozen parameters get no lr / weight_decay overrides.
                params.append(group)
                continue

            group['lr'] = self.base_lr
            use_decay = self.base_wd is not None
            if use_decay:
                group['weight_decay'] = self.base_wd

            # Matching always uses the '<prefix>.<name>' form, even when
            # ``prefix`` is empty (unlike the name used for logging below).
            match_target = f'{prefix}.{param_name}'
            rules = [
                self.paramwise_cfg[key] for key in self.paramwise_cfg
                if key in match_target
            ]
            if rules:
                # Multipliers of all matching keys are combined by product.
                group['lr'] *= reduce(
                    mul, [rule.get('lr_mult', 1.) for rule in rules])
                if use_decay:
                    group['weight_decay'] *= reduce(
                        mul, [rule.get('decay_mult', 1.) for rule in rules])

            params.append(group)

            display_name = f'{prefix}.{param_name}' if prefix else param_name
            for option, value in group.items():
                if option != 'params':
                    print_log(
                        f'paramwise_options -- '
                        f'{display_name}: {option} = {round(value, 8)}',
                        logger='current')

        # Descend into children, extending the dotted prefix as we go.
        for child_name, child in module.named_children():
            self.add_params(
                params,
                child,
                prefix=f'{prefix}.{child_name}' if prefix else child_name)