erwold commited on
Commit
1d4e763
·
1 Parent(s): 2d66916

Initial Commit

Browse files
Files changed (1) hide show
  1. app.py +116 -116
app.py CHANGED
@@ -54,126 +54,126 @@ class FluxInterface:
54
  self.MODEL_ID = "Djrango/Qwen2vl-Flux"
55
 
56
  def load_models(self):
57
- if self.models is not None:
58
- return
59
 
60
- import gc
61
- torch.cuda.empty_cache()
62
- gc.collect()
63
-
64
- logger.info("Starting model loading...")
65
-
66
- try:
67
- # 1. 首先加载小型模型和tokenizer
68
- tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
69
- tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")
70
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.MODEL_ID, subfolder="flux/scheduler", shift=1)
71
-
72
- # 2. 加载并优化CLIP text encoder
73
- text_encoder = CLIPTextModel.from_pretrained(
74
- self.MODEL_ID,
75
- subfolder="flux/text_encoder",
76
- torch_dtype=self.dtype,
77
- device_map="auto" # 让模型自动管理显存
78
- )
79
-
80
- # 3. 加载T5 encoder
81
- text_encoder_two = T5EncoderModel.from_pretrained(
82
- self.MODEL_ID,
83
- subfolder="flux/text_encoder_2",
84
- torch_dtype=self.dtype,
85
- device_map="auto"
86
- )
87
-
88
- # 清理一次显存
89
- torch.cuda.empty_cache()
90
- gc.collect()
91
-
92
- # 4. 加载VAE
93
- vae = AutoencoderKL.from_pretrained(
94
- self.MODEL_ID,
95
- subfolder="flux/vae",
96
- torch_dtype=self.dtype,
97
- device_map="auto"
98
- )
99
-
100
- # 5. 加载Transformer
101
- transformer = FluxTransformer2DModel.from_pretrained(
102
- self.MODEL_ID,
103
- subfolder="flux/transformer",
104
- torch_dtype=self.dtype,
105
- device_map="auto"
106
- )
107
-
108
- # 再次清理显存
109
  torch.cuda.empty_cache()
110
  gc.collect()
111
-
112
- # 6. 加载Qwen2VL
113
- qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
114
- self.MODEL_ID,
115
- subfolder="qwen2-vl",
116
- torch_dtype=self.dtype,
117
- device_map="auto"
118
- )
119
-
120
- # 7. 加载其他小组件
121
- connector = Qwen2Connector().to(self.dtype)
122
- connector_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/connector.pt"
123
- connector_state = torch.hub.load_state_dict_from_url(connector_path, map_location='cpu')
124
- connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
125
- connector.load_state_dict(connector_state)
126
- connector = connector.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype)
129
- t5_embedder_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/t5_embedder.pt"
130
- t5_embedder_state = torch.hub.load_state_dict_from_url(t5_embedder_path, map_location='cpu')
131
- t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
132
- self.t5_context_embedder.load_state_dict(t5_embedder_state)
133
- self.t5_context_embedder = self.t5_context_embedder.to(self.device)
134
-
135
- # 设置eval模式和关闭梯度
136
- for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]:
137
- if hasattr(model, 'eval'):
138
- model.eval()
139
- if hasattr(model, 'requires_grad_'):
140
- model.requires_grad_(False)
141
-
142
- logger.info("Models loaded successfully")
143
-
144
- self.models = {
145
- 'tokenizer': tokenizer,
146
- 'text_encoder': text_encoder,
147
- 'text_encoder_two': text_encoder_two,
148
- 'tokenizer_two': tokenizer_two,
149
- 'vae': vae,
150
- 'transformer': transformer,
151
- 'scheduler': scheduler,
152
- 'qwen2vl': qwen2vl,
153
- 'connector': connector
154
- }
155
-
156
- # 初始化processor和pipeline
157
- self.qwen2vl_processor = AutoProcessor.from_pretrained(
158
- self.MODEL_ID,
159
- subfolder="qwen2-vl",
160
- min_pixels=256*28*28,
161
- max_pixels=256*28*28
162
- )
163
-
164
- self.pipeline = FluxPipeline(
165
- transformer=transformer,
166
- scheduler=scheduler,
167
- vae=vae,
168
- text_encoder=text_encoder,
169
- tokenizer=tokenizer,
170
- )
171
-
172
- except Exception as e:
173
- logger.error(f"Error loading models: {str(e)}")
174
- torch.cuda.empty_cache()
175
- gc.collect()
176
- raise
177
 
178
  def resize_image(self, img, max_pixels=1050000):
179
  if not isinstance(img, Image.Image):
 
54
  self.MODEL_ID = "Djrango/Qwen2vl-Flux"
55
 
56
  def load_models(self):
57
+ if self.models is not None:
58
+ return
59
 
60
+ import gc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  torch.cuda.empty_cache()
62
  gc.collect()
63
+
64
+ logger.info("Starting model loading...")
65
+
66
+ try:
67
+ # 1. 首先加载小型模型和tokenizer
68
+ tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
69
+ tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")
70
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.MODEL_ID, subfolder="flux/scheduler", shift=1)
71
+
72
+ # 2. 加载并优化CLIP text encoder
73
+ text_encoder = CLIPTextModel.from_pretrained(
74
+ self.MODEL_ID,
75
+ subfolder="flux/text_encoder",
76
+ torch_dtype=self.dtype,
77
+ device_map="auto" # 让模型自动管理显存
78
+ )
79
+
80
+ # 3. 加载T5 encoder
81
+ text_encoder_two = T5EncoderModel.from_pretrained(
82
+ self.MODEL_ID,
83
+ subfolder="flux/text_encoder_2",
84
+ torch_dtype=self.dtype,
85
+ device_map="auto"
86
+ )
87
+
88
+ # 清理一次显存
89
+ torch.cuda.empty_cache()
90
+ gc.collect()
91
+
92
+ # 4. 加载VAE
93
+ vae = AutoencoderKL.from_pretrained(
94
+ self.MODEL_ID,
95
+ subfolder="flux/vae",
96
+ torch_dtype=self.dtype,
97
+ device_map="auto"
98
+ )
99
+
100
+ # 5. 加载Transformer
101
+ transformer = FluxTransformer2DModel.from_pretrained(
102
+ self.MODEL_ID,
103
+ subfolder="flux/transformer",
104
+ torch_dtype=self.dtype,
105
+ device_map="auto"
106
+ )
107
+
108
+ # 再次清理显存
109
+ torch.cuda.empty_cache()
110
+ gc.collect()
111
+
112
+ # 6. 加载Qwen2VL
113
+ qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
114
+ self.MODEL_ID,
115
+ subfolder="qwen2-vl",
116
+ torch_dtype=self.dtype,
117
+ device_map="auto"
118
+ )
119
+
120
+ # 7. 加载其他小组件
121
+ connector = Qwen2Connector().to(self.dtype)
122
+ connector_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/connector.pt"
123
+ connector_state = torch.hub.load_state_dict_from_url(connector_path, map_location='cpu')
124
+ connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
125
+ connector.load_state_dict(connector_state)
126
+ connector = connector.to(self.device)
127
 
128
+ self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype)
129
+ t5_embedder_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/t5_embedder.pt"
130
+ t5_embedder_state = torch.hub.load_state_dict_from_url(t5_embedder_path, map_location='cpu')
131
+ t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
132
+ self.t5_context_embedder.load_state_dict(t5_embedder_state)
133
+ self.t5_context_embedder = self.t5_context_embedder.to(self.device)
134
+
135
+ # 设置eval模式和关闭梯度
136
+ for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]:
137
+ if hasattr(model, 'eval'):
138
+ model.eval()
139
+ if hasattr(model, 'requires_grad_'):
140
+ model.requires_grad_(False)
141
+
142
+ logger.info("Models loaded successfully")
143
+
144
+ self.models = {
145
+ 'tokenizer': tokenizer,
146
+ 'text_encoder': text_encoder,
147
+ 'text_encoder_two': text_encoder_two,
148
+ 'tokenizer_two': tokenizer_two,
149
+ 'vae': vae,
150
+ 'transformer': transformer,
151
+ 'scheduler': scheduler,
152
+ 'qwen2vl': qwen2vl,
153
+ 'connector': connector
154
+ }
155
+
156
+ # 初始化processor和pipeline
157
+ self.qwen2vl_processor = AutoProcessor.from_pretrained(
158
+ self.MODEL_ID,
159
+ subfolder="qwen2-vl",
160
+ min_pixels=256*28*28,
161
+ max_pixels=256*28*28
162
+ )
163
+
164
+ self.pipeline = FluxPipeline(
165
+ transformer=transformer,
166
+ scheduler=scheduler,
167
+ vae=vae,
168
+ text_encoder=text_encoder,
169
+ tokenizer=tokenizer,
170
+ )
171
+
172
+ except Exception as e:
173
+ logger.error(f"Error loading models: {str(e)}")
174
+ torch.cuda.empty_cache()
175
+ gc.collect()
176
+ raise
177
 
178
  def resize_image(self, img, max_pixels=1050000):
179
  if not isinstance(img, Image.Image):