import torch
import torch.nn as nn
from .transformer_flux import FluxTransformer2DModel
class FluxNetwork(nn.Module):
    TARGET_REPLACE_MODULE = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]  # module types eligible for training
FLUX_PREFIX = "flux"
def __init__(self, flux_model: FluxTransformer2DModel):
super().__init__()
self.flux_model = flux_model
        self.trainable_component_names = []  # records the names of trainable components
@staticmethod
def generate_trainable_components(layers, num_transformer_blocks=19):
transformer_components = [
"attn.to_q",
"attn.to_k",
"attn.to_v",
"attn.to_out",
"norm1",
"norm1_context",
]
single_transformer_components = [
"attn.to_q",
"attn.to_k",
"attn.to_v",
"norm",
#"proj_mlp",
]
components = ["context_embedder"] # 添加 context_embedder
for layer in layers:
if layer < num_transformer_blocks:
prefix = f"transformer_blocks.{layer}"
base_components = transformer_components
else:
prefix = f"single_transformer_blocks.{layer - num_transformer_blocks}"
base_components = single_transformer_components
components.extend([f"{prefix}.{comp}" for comp in base_components])
return components
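    # Illustrative example (assumes the default of 19 double-stream blocks):
    # generate_trainable_components([0, 19]) returns
    #   ["context_embedder",
    #    "transformer_blocks.0.attn.to_q", ..., "transformer_blocks.0.norm1_context",
    #    "single_transformer_blocks.0.attn.to_q", ..., "single_transformer_blocks.0.norm"]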
def apply_to(self, layers=None, additional_components=None):
if layers is None:
            layers = list(range(57))  # default: all layers (19 double-stream + 38 single-stream blocks)
component_names = self.generate_trainable_components(layers)
if additional_components:
component_names.extend(additional_components)
        self.trainable_component_names = []  # reset the record of trainable components
for name in component_names:
try:
component = recursive_getattr(self.flux_model, name)
if isinstance(component, nn.Module):
component.requires_grad_(True)
self.trainable_component_names.append(name)
else:
print(f"Warning: {name} is not a Module, skipping.")
except AttributeError:
print(f"Warning: {name} not found in the model, skipping.")
def prepare_grad_etc(self):
        # freeze the whole flux_model, then unfreeze only the tracked components
self.flux_model.requires_grad_(False)
for name in self.trainable_component_names:
recursive_getattr(self.flux_model, name).requires_grad_(True)
def get_trainable_params(self):
        # return only the parameters that should be trained
params = []
for name in self.trainable_component_names:
params.extend(recursive_getattr(self.flux_model, name).parameters())
return params
def print_trainable_params_info(self):
total_params = 0
for name in self.trainable_component_names:
module = recursive_getattr(self.flux_model, name)
module_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
total_params += module_params
#print(f'{name}: {module_params} trainable parameters')
print(f'Total trainable params: {total_params}')
    def save_weights(self, file, dtype=None):
        # save the state dicts of the trainable components
        state_dict = {}
        for name in self.trainable_component_names:
            state_dict[name] = recursive_getattr(self.flux_model, name).state_dict()
        if dtype is not None:
            # cast detached CPU copies to the requested dtype before saving
            state_dict = {
                name: {k: t.detach().clone().to("cpu").to(dtype) for k, t in sd.items()}
                for name, sd in state_dict.items()
            }
        torch.save(state_dict, file)
def load_weights(self, file, device):
print(f"Loading weights from {file}")
try:
state_dict = torch.load(file, map_location=device, weights_only=True)
except Exception as e:
print(f"Failed to load weights from {file}: {str(e)}")
return False
successfully_loaded = []
failed_to_load = []
for name in state_dict:
try:
module = recursive_getattr(self.flux_model, name)
module_state_dict = module.state_dict()
                # check that the saved keys match the module's state dict keys
if set(state_dict[name].keys()) != set(module_state_dict.keys()):
raise ValueError(f"State dict keys for {name} do not match")
                # check that tensor shapes match
for key in state_dict[name]:
if state_dict[name][key].shape != module_state_dict[key].shape:
raise ValueError(f"Shape mismatch for {name}.{key}")
module.load_state_dict(state_dict[name])
successfully_loaded.append(name)
except Exception as e:
print(f"Failed to load weights for {name}: {str(e)}")
failed_to_load.append(name)
if successfully_loaded:
print(f"Successfully loaded weights for: {', '.join(successfully_loaded)}")
if failed_to_load:
print(f"Failed to load weights for: {', '.join(failed_to_load)}")
        return len(failed_to_load) == 0  # True only if every component loaded
# recursive attribute getter: resolves dotted paths such as "transformer_blocks.0.attn.to_q"
def recursive_getattr(obj, attr):
    # walk the dotted path one attribute at a time
    for name in attr.split("."):
        obj = getattr(obj, name)
    return obj
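# e.g. recursive_getattr(model, "transformer_blocks.0.attn.to_q") returns the module at
# model.transformer_blocks[0].attn.to_q (nn.ModuleList exposes its indices as attributes)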
# recursive attribute setter for dotted paths
def recursive_setattr(obj, attr, val):
    # walk to the parent of the final attribute, then set it
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], val)
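# Usage sketch (illustrative; assumes a FluxTransformer2DModel instance built
# elsewhere, and placeholder layer indices, file path, and learning rate):
#
#   network = FluxNetwork(model)
#   network.apply_to(layers=[0, 1, 19, 20])  # first two double- and single-stream blocks
#   network.prepare_grad_etc()               # freeze everything else
#   network.print_trainable_params_info()
#   optimizer = torch.optim.AdamW(network.get_trainable_params(), lr=1e-5)
#   ...                                      # training loop
#   network.save_weights("flux_trainable.pt", dtype=torch.bfloat16)
#   network.load_weights("flux_trainable.pt", device="cpu")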