File size: 2,434 Bytes
8c70653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Generator

from trainer.trainer_utils import get_optimizer


class CapacitronOptimizer:
    """Double optimizer class for the Capacitron model.

    Trainable model parameters are split into two groups, each driven by its
    own optimizer:

    * the *secondary* group holds only the parameter named
      ``capacitron_vae_layer.beta``;
    * the *primary* group holds every other trainable parameter.

    Parameters with ``requires_grad == False`` are given to neither optimizer.

    ``self.param_groups`` mirrors ``primary_optimizer.param_groups`` so code that
    reads learning rates through this wrapper sees the primary optimizer's groups.
    """

    def __init__(self, config, model_params: Generator) -> None:
        """Split the parameters and build the primary and secondary optimizers.

        Args:
            config: Object exposing ``optimizer_params`` and ``lr`` as
                attributes (attribute access is used, so a plain ``dict`` will
                not work). ``optimizer_params`` maps optimizer names to keyword
                argument dicts. Its first entry configures the primary optimizer,
                which uses ``config.lr`` as its learning rate. Its second entry
                configures the secondary optimizer, whose ``"lr"`` key becomes
                that optimizer's learning rate (the remaining keys are passed
                through). Any further entries are ignored.
            model_params: Iterable of ``(name, parameter)`` pairs — presumably
                ``model.named_parameters()``; confirm at the call site.

        Raises:
            IndexError: If ``config.optimizer_params`` has fewer than two entries.
            KeyError: If the second optimizer's parameters have no ``"lr"`` key.
        """
        # Store the groups as lists so the attributes stay usable after
        # get_optimizer (presumably) iterates them; the iterators returned by
        # split_model_parameters can only be traversed once.
        self.primary_params, self.secondary_params = (
            list(group) for group in self.split_model_parameters(model_params)
        )

        optimizer_entries = list(config.optimizer_params.items())
        primary_name, primary_kwargs = optimizer_entries[0]
        secondary_name, secondary_kwargs = optimizer_entries[1]

        self.primary_optimizer = get_optimizer(
            primary_name,
            primary_kwargs,
            config.lr,
            parameters=self.primary_params,
        )

        # The secondary optimizer's learning rate comes from its own parameter
        # dict rather than from config.lr, so "lr" is split off before the
        # remaining kwargs are passed along.
        self.secondary_optimizer = get_optimizer(
            secondary_name,
            self.extract_optimizer_parameters(secondary_kwargs),
            secondary_kwargs["lr"],
            parameters=self.secondary_params,
        )

        self.param_groups = self.primary_optimizer.param_groups

    def first_step(self):
        """Step the secondary optimizer, then clear the gradients of both optimizers.

        NOTE(review): presumably called by the training loop after the backward
        pass that produces the ``beta`` gradient and before the primary update —
        confirm against the caller.
        """
        self.secondary_optimizer.step()
        self.secondary_optimizer.zero_grad()
        self.primary_optimizer.zero_grad()

    def step(self):
        """Step the primary optimizer only; the secondary one is stepped in ``first_step``."""
        # Update param groups to display the correct learning rate
        self.param_groups = self.primary_optimizer.param_groups
        self.primary_optimizer.step()

    def zero_grad(self, set_to_none=False):
        """Clear the gradients of both optimizers, forwarding ``set_to_none``."""
        self.primary_optimizer.zero_grad(set_to_none)
        self.secondary_optimizer.zero_grad(set_to_none)

    def load_state_dict(self, state_dict):
        """Restore both optimizers from ``[primary_state, secondary_state]``.

        Args:
            state_dict: Two-element sequence in the format returned by
                ``state_dict()``.
        """
        self.primary_optimizer.load_state_dict(state_dict[0])
        self.secondary_optimizer.load_state_dict(state_dict[1])
        # torch.optim.Optimizer.load_state_dict rebinds ``param_groups`` to a new
        # list, so refresh the mirror now instead of leaving it stale until the
        # next step().
        self.param_groups = self.primary_optimizer.param_groups

    def state_dict(self):
        """Return ``[primary_state, secondary_state]`` for checkpointing."""
        return [self.primary_optimizer.state_dict(), self.secondary_optimizer.state_dict()]

    @staticmethod
    def split_model_parameters(model_params: Generator) -> list:
        """Partition trainable parameters into ``[primary_iter, secondary_iter]``.

        Only ``capacitron_vae_layer.beta`` goes to the secondary group; every
        other parameter with ``requires_grad`` set goes to the primary group.
        Frozen parameters are dropped. Each group is returned as a one-shot
        iterator.
        """
        primary_params = []
        secondary_params = []
        for name, param in model_params:
            if not param.requires_grad:
                continue
            if name == "capacitron_vae_layer.beta":
                secondary_params.append(param)
            else:
                primary_params.append(param)
        return [iter(primary_params), iter(secondary_params)]

    @staticmethod
    def extract_optimizer_parameters(params: dict) -> dict:
        """Return a copy of ``params`` without the ``"lr"`` key."""
        return {k: v for k, v in params.items() if k != "lr"}