import torch
import torch.nn as nn
from .transformer_flux import FluxTransformer2DModel

class FluxNetwork(nn.Module):
    TARGET_REPLACE_MODULE = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]  # module types whose sub-components may be trained
    FLUX_PREFIX = "flux"
    
    def __init__(self, flux_model: FluxTransformer2DModel):
        super().__init__()
        self.flux_model = flux_model
        self.trainable_component_names = []  # records the names of trainable components
      
    @staticmethod
    def generate_trainable_components(layers, num_transformer_blocks=19):
        """Build dotted module paths for the requested block indices.

        Indices below ``num_transformer_blocks`` select double-stream
        ``FluxTransformerBlock``s; higher indices select single-stream
        ``FluxSingleTransformerBlock``s (offset by ``num_transformer_blocks``).
        """
        # Sub-modules trained inside each double-stream transformer block.
        transformer_components = [
            "attn.to_q",
            "attn.to_k",
            "attn.to_v",
            "attn.to_out",
            "norm1",
            "norm1_context",
        ]
        
        # Sub-modules trained inside each single-stream transformer block.
        single_transformer_components = [
            "attn.to_q",
            "attn.to_k",
            "attn.to_v",
            "norm",
            #"proj_mlp",
        ]
        
        components = ["context_embedder"]  # always include the context embedder
        for layer in layers:
            if layer < num_transformer_blocks:
                prefix = f"transformer_blocks.{layer}"
                base_components = transformer_components
            else:
                prefix = f"single_transformer_blocks.{layer - num_transformer_blocks}"
                base_components = single_transformer_components
            components.extend([f"{prefix}.{comp}" for comp in base_components])
        
        return components
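
    # Illustrative example (hypothetical indices, output abbreviated): index 0
    # selects the first double-stream block, index 19 the first single-stream block:
    #
    #   FluxNetwork.generate_trainable_components([0, 19])
    #   # -> ['context_embedder',
    #   #     'transformer_blocks.0.attn.to_q', ..., 'transformer_blocks.0.norm1_context',
    #   #     'single_transformer_blocks.0.attn.to_q', ..., 'single_transformer_blocks.0.norm']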
    
    def apply_to(self, layers=None, additional_components=None):
        if layers is None:
            layers = list(range(57))  # default: all 57 blocks (19 double-stream + 38 single-stream)
        
        component_names = self.generate_trainable_components(layers)
        
        if additional_components:
            component_names.extend(additional_components)
        
        self.trainable_component_names = []  # reset the record before collecting
        for name in component_names:
            try:
                component = recursive_getattr(self.flux_model, name)
                if isinstance(component, nn.Module):
                    component.requires_grad_(True)
                    self.trainable_component_names.append(name)
                else:
                    print(f"Warning: {name} is not a Module, skipping.")
            except AttributeError:
                print(f"Warning: {name} not found in the model, skipping.")
                        
    def prepare_grad_etc(self):
        # Freeze the whole model, then re-enable gradients only for the selected components.
        self.flux_model.requires_grad_(False)
        for name in self.trainable_component_names:
            recursive_getattr(self.flux_model, name).requires_grad_(True)
                
    def get_trainable_params(self):
        # Collect the parameters of every trainable component.
        params = []
        for name in self.trainable_component_names:
            params.extend(recursive_getattr(self.flux_model, name).parameters())
        return params
    
    def print_trainable_params_info(self):
        total_params = 0
        for name in self.trainable_component_names:
            module = recursive_getattr(self.flux_model, name)
            module_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
            total_params += module_params
            #print(f'{name}: {module_params} trainable parameters')
        print(f'Total trainable params: {total_params}')
    
    def save_weights(self, file, dtype=None):
        # Save only the parameters of the trainable components.
        state_dict = {}
        for name in self.trainable_component_names:
            state_dict[name] = recursive_getattr(self.flux_model, name).state_dict()
        if dtype is not None:
            # Rebuild the dict explicitly: reassigning the loop variable
            # would leave state_dict unchanged.
            state_dict = {
                name: {k: t.detach().clone().to("cpu").to(dtype) for k, t in sd.items()}
                for name, sd in state_dict.items()
            }
        torch.save(state_dict, file)
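
    # Note: the saved file is a nested dict keyed by component path, e.g.
    #   {'context_embedder': {...}, 'transformer_blocks.0.attn.to_q': {...}, ...}
    # load_weights() below expects exactly this layout.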

    def load_weights(self, file, device):
        print(f"Loading weights from {file}")
        try:
            state_dict = torch.load(file, map_location=device, weights_only=True)
        except Exception as e:
            print(f"Failed to load weights from {file}: {str(e)}")
            return False

        successfully_loaded = []
        failed_to_load = []

        for name in state_dict:
            try:
                module = recursive_getattr(self.flux_model, name)
                module_state_dict = module.state_dict()
                
                # Check that the saved keys match the module's current state_dict.
                if set(state_dict[name].keys()) != set(module_state_dict.keys()):
                    raise ValueError(f"State dict keys for {name} do not match")
                
                # Check that tensor shapes match before loading.
                for key in state_dict[name]:
                    if state_dict[name][key].shape != module_state_dict[key].shape:
                        raise ValueError(f"Shape mismatch for {name}.{key}")
                
                module.load_state_dict(state_dict[name])
                successfully_loaded.append(name)
                 
            except Exception as e:
                print(f"Failed to load weights for {name}: {str(e)}")
                failed_to_load.append(name)

        if successfully_loaded:
            print(f"Successfully loaded weights for: {', '.join(successfully_loaded)}")
        if failed_to_load:
            print(f"Failed to load weights for: {', '.join(failed_to_load)}")

        return len(failed_to_load) == 0  # True only if every component loaded
            
# Resolve a dotted attribute path, e.g. "transformer_blocks.0.attn.to_q".
def recursive_getattr(obj, attr):
    for name in attr.split("."):
        obj = getattr(obj, name)
    return obj

# Set the attribute at the end of a dotted path on its parent object.
def recursive_setattr(obj, attr, val):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], val)
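
# --- Usage sketch (illustrative; the checkpoint path, layer indices, and
# hyperparameters below are placeholders, and from_pretrained assumes the
# local FluxTransformer2DModel keeps the diffusers ModelMixin API) ---
#
#   model = FluxTransformer2DModel.from_pretrained("path/to/flux/transformer")
#   network = FluxNetwork(model)
#   network.apply_to(layers=[0, 1, 19])  # two double-stream blocks + one single-stream block
#   network.prepare_grad_etc()
#   network.print_trainable_params_info()
#   optimizer = torch.optim.AdamW(network.get_trainable_params(), lr=1e-5)
#   ...                                  # training loop
#   network.save_weights("flux_partial.pt", dtype=torch.bfloat16)
#   network.load_weights("flux_partial.pt", device="cpu")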