alessandro trinca tornidor commited on
Commit
8ced4d2
·
1 Parent(s): 05528cb

feat: try moving lisa-on-cuda code within the project

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +1 -1
  2. lisa_on_cuda/LISA.py +471 -0
  3. lisa_on_cuda/__init__.py +18 -0
  4. lisa_on_cuda/__version__.py +8 -0
  5. lisa_on_cuda/llava/__init__.py +1 -0
  6. lisa_on_cuda/llava/constants.py +12 -0
  7. lisa_on_cuda/llava/conversation.py +399 -0
  8. lisa_on_cuda/llava/mm_utils.py +88 -0
  9. lisa_on_cuda/llava/model/__init__.py +2 -0
  10. lisa_on_cuda/llava/model/apply_delta.py +57 -0
  11. lisa_on_cuda/llava/model/builder.py +206 -0
  12. lisa_on_cuda/llava/model/consolidate.py +30 -0
  13. lisa_on_cuda/llava/model/language_model/llava_llama.py +167 -0
  14. lisa_on_cuda/llava/model/language_model/llava_mpt.py +174 -0
  15. lisa_on_cuda/llava/model/language_model/mpt/adapt_tokenizer.py +46 -0
  16. lisa_on_cuda/llava/model/language_model/mpt/attention.py +526 -0
  17. lisa_on_cuda/llava/model/language_model/mpt/blocks.py +92 -0
  18. lisa_on_cuda/llava/model/language_model/mpt/configuration_mpt.py +199 -0
  19. lisa_on_cuda/llava/model/language_model/mpt/custom_embedding.py +11 -0
  20. lisa_on_cuda/llava/model/language_model/mpt/flash_attn_triton.py +1087 -0
  21. lisa_on_cuda/llava/model/language_model/mpt/hf_prefixlm_converter.py +750 -0
  22. lisa_on_cuda/llava/model/language_model/mpt/meta_init_context.py +111 -0
  23. lisa_on_cuda/llava/model/language_model/mpt/modeling_mpt.py +538 -0
  24. lisa_on_cuda/llava/model/language_model/mpt/norm.py +106 -0
  25. lisa_on_cuda/llava/model/language_model/mpt/param_init_fns.py +419 -0
  26. lisa_on_cuda/llava/model/llava_arch.py +395 -0
  27. lisa_on_cuda/llava/model/make_delta.py +63 -0
  28. lisa_on_cuda/llava/model/multimodal_encoder/builder.py +17 -0
  29. lisa_on_cuda/llava/model/multimodal_encoder/clip_encoder.py +87 -0
  30. lisa_on_cuda/llava/model/utils.py +26 -0
  31. lisa_on_cuda/llava/train/llama_flash_attn_monkey_patch.py +126 -0
  32. lisa_on_cuda/llava/train/llava_trainer.py +67 -0
  33. lisa_on_cuda/llava/train/train.py +1038 -0
  34. lisa_on_cuda/llava/train/train_mem.py +14 -0
  35. lisa_on_cuda/llava/utils.py +134 -0
  36. lisa_on_cuda/routes.py +21 -0
  37. lisa_on_cuda/segment_anything/__init__.py +10 -0
  38. lisa_on_cuda/segment_anything/automatic_mask_generator.py +372 -0
  39. lisa_on_cuda/segment_anything/build_sam.py +108 -0
  40. lisa_on_cuda/segment_anything/modeling/__init__.py +11 -0
  41. lisa_on_cuda/segment_anything/modeling/common.py +43 -0
  42. lisa_on_cuda/segment_anything/modeling/image_encoder.py +426 -0
  43. lisa_on_cuda/segment_anything/modeling/mask_decoder.py +191 -0
  44. lisa_on_cuda/segment_anything/modeling/prompt_encoder.py +238 -0
  45. lisa_on_cuda/segment_anything/modeling/sam.py +184 -0
  46. lisa_on_cuda/segment_anything/modeling/transformer.py +242 -0
  47. lisa_on_cuda/segment_anything/predictor.py +284 -0
  48. lisa_on_cuda/segment_anything/utils/__init__.py +5 -0
  49. lisa_on_cuda/segment_anything/utils/amg.py +346 -0
  50. lisa_on_cuda/segment_anything/utils/onnx.py +157 -0
app.py CHANGED
@@ -120,7 +120,7 @@ async def health() -> JSONResponse:
120
  return JSONResponse(status_code=200, content={"msg": "still alive..."})
121
 
122
 
123
- # try executing gpu_initialization() not within infer_lisa_gradio()
124
  # gpu_initialization()
125
 
126
  @spaces.GPU
 
120
  return JSONResponse(status_code=200, content={"msg": "still alive..."})
121
 
122
 
123
+ # try executingx gpu_initialization() not within infer_lisa_gradio()
124
  # gpu_initialization()
125
 
126
  @spaces.GPU
lisa_on_cuda/LISA.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .llava.model.language_model.llava_llama import (LlavaLlamaForCausalLM, LlavaLlamaModel)
8
+ from .segment_anything import build_sam_vit_h
9
+
10
+ embedding_dict = {}
11
+
12
+
13
+ def dice_loss(
14
+ inputs: torch.Tensor,
15
+ targets: torch.Tensor,
16
+ num_masks: float,
17
+ scale=1000, # 100000.0,
18
+ eps=1e-6,
19
+ ) -> torch.Tensor:
20
+ """
21
+ Compute the DICE loss, similar to generalized IOU for masks.
22
+ Arguments 'num_masks', 'scale', 'eps' and return value 'loss' are undocumented in original project
23
+ https://github.com/dvlab-research/LISA
24
+ About 'num_masks': it's similar to 'avg_factor' in weight_reduce_loss() from
25
+ https://github.com/open-mmlab/mmdetection/blob/e9cae2d0787cd5c2fc6165a6061f92fa09e48fb1/mmdet/models/losses/utils.py#L30
26
+
27
+ Args:
28
+ inputs: A float tensor of arbitrary shape.
29
+ The predictions for each example.
30
+ targets: A float tensor with the same shape as inputs. Stores the binary
31
+ classification label for each element in inputs
32
+ (0 for the negative class and 1 for the positive class).
33
+ num_masks: Average factor when computing the mean of losses (?)
34
+ scale: weight factor applied before computing mean of losses (?)
35
+ eps: Avoid dividing by zero (?)
36
+
37
+ return:
38
+ Processed loss values.
39
+
40
+ """
41
+ inputs = inputs.sigmoid()
42
+ inputs = inputs.flatten(1, 2)
43
+ targets = targets.flatten(1, 2)
44
+ numerator = 2 * (inputs / scale * targets).sum(-1)
45
+ denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
46
+ loss = 1 - (numerator + eps) / (denominator + eps)
47
+
48
+ loss = loss.sum() / (num_masks + 1e-8)
49
+
50
+ return loss
51
+
52
+
53
+ def sigmoid_ce_loss(
54
+ inputs: torch.Tensor,
55
+ targets: torch.Tensor,
56
+ num_masks: float,
57
+ ):
58
+ """
59
+ Args:
60
+ inputs: A float tensor of arbitrary shape.
61
+ The predictions for each example.
62
+ targets: A float tensor with the same shape as inputs. Stores the binary
63
+ classification label for each element in inputs
64
+ (0 for the negative class and 1 for the positive class).
65
+ num_masks: Average factor when computing the mean of losses (?)
66
+
67
+ Returns:
68
+ Loss tensor
69
+ """
70
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
71
+ loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
72
+ return loss
73
+
74
+
75
+ class LisaMetaModel:
76
+ def __init__(
77
+ self,
78
+ config,
79
+ **kwargs,
80
+ ):
81
+ super(LisaMetaModel, self).__init__(config)
82
+
83
+ self.config = config
84
+ if not hasattr(self.config, "train_mask_decoder"):
85
+ self.config.train_mask_decoder = kwargs["train_mask_decoder"]
86
+ self.config.out_dim = kwargs["out_dim"]
87
+ self.vision_pretrained = kwargs.get("vision_pretrained", None)
88
+ else:
89
+ self.vision_pretrained = kwargs.get("vision_pretrained", None)
90
+ self.initialize_lisa_modules(self.config)
91
+
92
+ def initialize_lisa_modules(self, config):
93
+ # SAM
94
+ self.visual_model = build_sam_vit_h(self.vision_pretrained)
95
+ for param in self.visual_model.parameters():
96
+ param.requires_grad = False
97
+ if config.train_mask_decoder:
98
+ self.visual_model.mask_decoder.train()
99
+ for param in self.visual_model.mask_decoder.parameters():
100
+ param.requires_grad = True
101
+
102
+ # Projection layer
103
+ in_dim = config.hidden_size
104
+ out_dim = config.out_dim
105
+ text_fc = [
106
+ nn.Linear(in_dim, in_dim),
107
+ nn.ReLU(inplace=True),
108
+ nn.Linear(in_dim, out_dim),
109
+ nn.Dropout(0.0),
110
+ ]
111
+ self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
112
+ self.text_hidden_fcs.train()
113
+ for param in self.text_hidden_fcs.parameters():
114
+ param.requires_grad = True
115
+
116
+
117
+ class LisaModel(LisaMetaModel, LlavaLlamaModel):
118
+ def __init__(
119
+ self,
120
+ config,
121
+ **kwargs,
122
+ ):
123
+ super(LisaModel, self).__init__(config, **kwargs)
124
+
125
+ self.config.use_cache = False
126
+ self.config.vision_tower = self.config.mm_vision_tower
127
+ self.config.mm_vision_select_feature = "patch"
128
+ self.config.image_aspect_ratio = "square"
129
+ self.config.image_grid_pinpoints = None
130
+ self.config.tune_mm_mlp_adapter = False
131
+ self.config.freeze_mm_mlp_adapter = True
132
+ self.config.pretrain_mm_mlp_adapter = None
133
+ self.config.mm_use_im_patch_token = False
134
+
135
+
136
+ class LISAForCausalLM(LlavaLlamaForCausalLM):
137
+ def __init__(
138
+ self,
139
+ config,
140
+ **kwargs,
141
+ ):
142
+ if not hasattr(config, "train_mask_decoder"):
143
+ config.mm_use_im_start_end = kwargs.pop("use_mm_start_end", True)
144
+ config.mm_vision_tower = kwargs.get(
145
+ "vision_tower", "openai/clip-vit-large-patch14"
146
+ )
147
+ self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
148
+ self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
149
+ self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
150
+ else:
151
+ config.mm_vision_tower = config.vision_tower
152
+
153
+ self.seg_token_idx = kwargs.pop("seg_token_idx")
154
+
155
+ super().__init__(config)
156
+
157
+ self.model = LisaModel(config, **kwargs)
158
+
159
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
160
+
161
+ # Initialize weights and apply final processing
162
+ self.post_init()
163
+
164
+ def get_visual_embs(self, pixel_values: torch.FloatTensor):
165
+ with torch.no_grad():
166
+ image_embeddings_list = []
167
+ for i in range(pixel_values.shape[0]):
168
+ torch.cuda.empty_cache()
169
+ image_embeddings = self.model.visual_model.image_encoder(
170
+ pixel_values[i].unsqueeze(0)
171
+ )
172
+ image_embeddings_list.append(image_embeddings)
173
+ torch.cuda.empty_cache()
174
+ image_embeddings = torch.cat(image_embeddings_list, 0)
175
+ return image_embeddings
176
+
177
+ def forward(self, **kwargs):
178
+ if "past_key_values" in kwargs:
179
+ return super().forward(**kwargs)
180
+ return self.model_forward(**kwargs)
181
+
182
+ def model_forward(
183
+ self,
184
+ images: torch.FloatTensor,
185
+ images_clip: torch.FloatTensor,
186
+ input_ids: torch.LongTensor,
187
+ labels: torch.LongTensor,
188
+ attention_masks: torch.LongTensor,
189
+ offset: torch.LongTensor,
190
+ masks_list: List[torch.FloatTensor],
191
+ label_list: List[torch.Tensor],
192
+ resize_list: List[tuple],
193
+ inference: bool = False,
194
+ **kwargs,
195
+ ):
196
+ image_embeddings = self.get_visual_embs(images)
197
+ batch_size = image_embeddings.shape[0]
198
+ assert batch_size == len(offset) - 1
199
+
200
+ seg_token_mask = input_ids[:, 1:] == self.seg_token_idx
201
+ seg_token_mask = torch.cat(
202
+ [
203
+ seg_token_mask,
204
+ torch.zeros((seg_token_mask.shape[0], 1)).bool().cuda(),
205
+ ],
206
+ dim=1,
207
+ )
208
+ # hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
209
+ seg_token_mask = torch.cat(
210
+ [torch.zeros((seg_token_mask.shape[0], 255)).bool().cuda(), seg_token_mask],
211
+ dim=1,
212
+ )
213
+
214
+ if inference:
215
+ n_batch = 1
216
+ length = input_ids.shape[0]
217
+ assert images_clip.shape[0] == 1
218
+ images_clip_extend = images_clip.expand(length, -1, -1, -1).contiguous()
219
+
220
+ output_hidden_states = []
221
+ for i in range(n_batch):
222
+ start_i, end_i = i * length, min((i + 1) * length, input_ids.shape[0])
223
+ output_i = super().forward(
224
+ images=images_clip_extend[: end_i - start_i],
225
+ attention_mask=attention_masks[start_i:end_i],
226
+ input_ids=input_ids[start_i:end_i],
227
+ output_hidden_states=True,
228
+ )
229
+ output_hidden_states.append(output_i.hidden_states)
230
+ torch.cuda.empty_cache()
231
+
232
+ output_hidden_states_list = []
233
+ output_hidden_states_level = torch.cat(output_hidden_states, dim=0)
234
+ output_hidden_states_list.append(output_hidden_states_level)
235
+ output_hidden_states = output_hidden_states_list
236
+ output = None
237
+
238
+ else:
239
+ images_clip_list = []
240
+ for i in range(len(offset) - 1):
241
+ start_i, end_i = offset[i], offset[i + 1]
242
+ images_clip_i = (
243
+ images_clip[i]
244
+ .unsqueeze(0)
245
+ .expand(end_i - start_i, -1, -1, -1)
246
+ .contiguous()
247
+ )
248
+ images_clip_list.append(images_clip_i)
249
+ images_clip = torch.cat(images_clip_list, dim=0)
250
+
251
+ output = super().forward(
252
+ images=images_clip,
253
+ attention_mask=attention_masks,
254
+ input_ids=input_ids,
255
+ labels=labels,
256
+ output_hidden_states=True,
257
+ )
258
+ output_hidden_states = output.hidden_states
259
+
260
+ hidden_states = []
261
+
262
+ assert len(self.model.text_hidden_fcs) == 1
263
+ hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states[-1]))
264
+
265
+ last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
266
+ pred_embeddings = last_hidden_state[seg_token_mask]
267
+ seg_token_counts = seg_token_mask.int().sum(-1) # [bs, ]
268
+
269
+ seg_token_offset = seg_token_counts.cumsum(-1)
270
+ seg_token_offset = torch.cat(
271
+ [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
272
+ )
273
+
274
+ seg_token_offset = seg_token_offset[offset]
275
+
276
+ pred_embeddings_ = []
277
+ for i in range(len(seg_token_offset) - 1):
278
+ start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
279
+ pred_embeddings_.append(pred_embeddings[start_i:end_i])
280
+ pred_embeddings = pred_embeddings_
281
+
282
+ multimask_output = False
283
+ pred_masks = []
284
+ for i in range(len(pred_embeddings)):
285
+ (
286
+ sparse_embeddings,
287
+ dense_embeddings,
288
+ ) = self.model.visual_model.prompt_encoder(
289
+ points=None,
290
+ boxes=None,
291
+ masks=None,
292
+ text_embeds=pred_embeddings[i].unsqueeze(1),
293
+ )
294
+ sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
295
+ low_res_masks, iou_predictions = self.model.visual_model.mask_decoder(
296
+ image_embeddings=image_embeddings[i].unsqueeze(0),
297
+ image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
298
+ sparse_prompt_embeddings=sparse_embeddings,
299
+ dense_prompt_embeddings=dense_embeddings,
300
+ multimask_output=multimask_output,
301
+ )
302
+ pred_mask = self.model.visual_model.postprocess_masks(
303
+ low_res_masks,
304
+ input_size=resize_list[i],
305
+ original_size=label_list[i].shape,
306
+ )
307
+ pred_masks.append(pred_mask[:, 0])
308
+
309
+ model_output = output
310
+ gt_masks = masks_list
311
+
312
+ if inference:
313
+ return {
314
+ "pred_masks": pred_masks,
315
+ "gt_masks": gt_masks,
316
+ }
317
+
318
+ output = model_output.logits
319
+
320
+ ce_loss = model_output.loss
321
+ ce_loss = ce_loss * self.ce_loss_weight
322
+ mask_bce_loss = 0
323
+ mask_dice_loss = 0
324
+ num_masks = 0
325
+ for batch_idx in range(len(pred_masks)):
326
+ gt_mask = gt_masks[batch_idx]
327
+ pred_mask = pred_masks[batch_idx]
328
+
329
+ assert (
330
+ gt_mask.shape[0] == pred_mask.shape[0]
331
+ ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
332
+ gt_mask.shape, pred_mask.shape
333
+ )
334
+ mask_bce_loss += (
335
+ sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
336
+ * gt_mask.shape[0]
337
+ )
338
+ mask_dice_loss += (
339
+ dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
340
+ * gt_mask.shape[0]
341
+ )
342
+ num_masks += gt_mask.shape[0]
343
+
344
+ mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
345
+ mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
346
+ mask_loss = mask_bce_loss + mask_dice_loss
347
+
348
+ loss = ce_loss + mask_loss
349
+
350
+ return {
351
+ "loss": loss,
352
+ "ce_loss": ce_loss,
353
+ "mask_bce_loss": mask_bce_loss,
354
+ "mask_dice_loss": mask_dice_loss,
355
+ "mask_loss": mask_loss,
356
+ }
357
+
358
+ def evaluate(
359
+ self,
360
+ images_clip,
361
+ images,
362
+ input_ids,
363
+ resize_list,
364
+ original_size_list,
365
+ max_new_tokens=32,
366
+ tokenizer=None,
367
+ model_logger=None,
368
+ embedding_key=None
369
+ ):
370
+ with torch.no_grad():
371
+ if model_logger is None:
372
+ import logging
373
+ model_logger = logging
374
+ model_logger.debug("start output generation...")
375
+ outputs = self.generate(
376
+ images=images_clip,
377
+ input_ids=input_ids,
378
+ max_new_tokens=max_new_tokens,
379
+ num_beams=1,
380
+ output_hidden_states=True,
381
+ return_dict_in_generate=True,
382
+ )
383
+ model_logger.debug("done output generation...")
384
+ output_hidden_states = outputs.hidden_states[-1]
385
+ output_ids = outputs.sequences
386
+
387
+ seg_token_mask = output_ids[:, 1:] == self.seg_token_idx
388
+ # hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
389
+ model_logger.debug(f"start torch.cat to seg_token_mask...")
390
+ seg_token_mask = torch.cat(
391
+ [
392
+ torch.zeros((seg_token_mask.shape[0], 255)).bool().cuda(),
393
+ seg_token_mask,
394
+ ],
395
+ dim=1,
396
+ )
397
+ model_logger.debug("done torch.cat to seg_token_mask...")
398
+
399
+ hidden_states = []
400
+
401
+ assert len(self.model.text_hidden_fcs) == 1
402
+ hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states))
403
+
404
+ model_logger.debug("start torch.stack to last_hidden_state...")
405
+ last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
406
+ model_logger.debug("done torch.stack to last_hidden_state...")
407
+ pred_embeddings = last_hidden_state[seg_token_mask]
408
+
409
+ seg_token_counts = seg_token_mask.int().sum(-1) # [bs, ]
410
+ seg_token_offset = seg_token_counts.cumsum(-1)
411
+ model_logger.debug(f"start torch.cat to seg_token_offset...")
412
+ seg_token_offset = torch.cat(
413
+ [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
414
+ )
415
+ model_logger.debug("done torch.cat to seg_token_offset...")
416
+
417
+ pred_embeddings_ = []
418
+ for i in range(len(seg_token_offset) - 1):
419
+ start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
420
+ pred_embeddings_.append(pred_embeddings[start_i:end_i])
421
+ pred_embeddings = pred_embeddings_
422
+
423
+ model_logger.debug(f"start get_visual_embs to image_embeddings with embedding_key {embedding_key}.")
424
+
425
+ if embedding_key is None:
426
+ image_embeddings = self.get_visual_embs(images)
427
+ else:
428
+ try:
429
+ image_embeddings = embedding_dict[embedding_key]
430
+ except KeyError:
431
+ model_logger.debug(f"embedding_key {embedding_key} not in embedding_dict, creating embedding now!")
432
+ image_embeddings = self.get_visual_embs(images)
433
+ embedding_dict[embedding_key] = image_embeddings
434
+ model_logger.debug(f"image embedding added in embedding_dict with embedding_key {embedding_key}!")
435
+
436
+ model_logger.debug("done get_visual_embs to image_embeddings...")
437
+
438
+ multimask_output = False
439
+ pred_masks = []
440
+ for i in range(len(pred_embeddings)):
441
+ model_logger.debug(f"start ({i}nth time) visual_model.prompt_encoder to sparse/dense")
442
+ (
443
+ sparse_embeddings,
444
+ dense_embeddings,
445
+ ) = self.model.visual_model.prompt_encoder(
446
+ points=None,
447
+ boxes=None,
448
+ masks=None,
449
+ text_embeds=pred_embeddings[i].unsqueeze(1),
450
+ )
451
+ model_logger.debug(f"done ({i}nth) visual_model.prompt_encoder to sparse/dense, start sparse2sparse")
452
+ sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
453
+ model_logger.debug(f"done ({i}nth) sparse2sparse, start visual_model.mask_decoder")
454
+ low_res_masks, iou_predictions = self.model.visual_model.mask_decoder(
455
+ image_embeddings=image_embeddings[i].unsqueeze(0),
456
+ image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
457
+ sparse_prompt_embeddings=sparse_embeddings,
458
+ dense_prompt_embeddings=dense_embeddings,
459
+ multimask_output=multimask_output,
460
+ )
461
+ model_logger.debug(f"done ({i}nth) visual_model.mask_decoder, start postprocess_masks")
462
+ pred_mask = self.model.visual_model.postprocess_masks(
463
+ low_res_masks,
464
+ input_size=resize_list[i],
465
+ original_size=original_size_list[i],
466
+ )
467
+ model_logger.debug(f"done ({i}nth) postprocess_masks")
468
+ pred_masks.append(pred_mask[:, 0])
469
+
470
+ model_logger.debug(f"env evaluate! ")
471
+ return output_ids, pred_masks
lisa_on_cuda/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import structlog
5
+ from dotenv import load_dotenv
6
+ from samgis_core.utilities.session_logger import setup_logging
7
+
8
+
9
+ load_dotenv()
10
+ project_root_folder = Path(globals().get("__file__", "./_")).absolute().parent
11
+ workdir = Path(os.getenv("WORKDIR", project_root_folder))
12
+ static_dist_folder = Path(workdir) / "static" / "dist"
13
+ static_dist_folder = Path(os.getenv("FASTAPI_STATIC", static_dist_folder))
14
+ model_folder = Path(project_root_folder / "machine_learning_models")
15
+
16
+ log_level = os.getenv("LOG_LEVEL", "INFO")
17
+ setup_logging(log_level=log_level)
18
+ app_logger = structlog.stdlib.get_logger()
lisa_on_cuda/__version__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import importlib.metadata
2
+
3
+
4
+ try:
5
+ __version__ = importlib.metadata.version(__package__ or __name__)
6
+ except importlib.metadata.PackageNotFoundError or ImportError as e:
7
+ print(f"metadata::e: {type(e)}, {e}: package installed?")
8
+ __version__ = "1.0.0"
lisa_on_cuda/llava/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import LlavaLlamaForCausalLM
lisa_on_cuda/llava/constants.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
lisa_on_cuda/llava/conversation.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import Enum, auto
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+
9
+ SINGLE = auto()
10
+ TWO = auto()
11
+ MPT = auto()
12
+ PLAIN = auto()
13
+ LLAMA_2 = auto()
14
+
15
+
16
+ @dataclasses.dataclass
17
+ class Conversation:
18
+ """A class that keeps all conversation history."""
19
+
20
+ system: str
21
+ roles: List[str]
22
+ messages: List[List[str]]
23
+ offset: int
24
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
25
+ sep: str = "###"
26
+ sep2: str = None
27
+ version: str = "Unknown"
28
+
29
+ skip_next: bool = False
30
+
31
+ def get_prompt(self):
32
+ messages = self.messages
33
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
34
+ messages = self.messages.copy()
35
+ init_role, init_msg = messages[0].copy()
36
+ init_msg = init_msg[0].replace("<image>", "").strip()
37
+ if "mmtag" in self.version:
38
+ messages[0] = (init_role, init_msg)
39
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
40
+ messages.insert(1, (self.roles[1], "Received."))
41
+ else:
42
+ messages[0] = (init_role, "<image>\n" + init_msg)
43
+
44
+ if self.sep_style == SeparatorStyle.SINGLE:
45
+ ret = self.system + self.sep
46
+ for role, message in messages:
47
+ if message:
48
+ if type(message) is tuple:
49
+ message, _, _ = message
50
+ ret += role + ": " + message + self.sep
51
+ else:
52
+ ret += role + ":"
53
+ elif self.sep_style == SeparatorStyle.TWO:
54
+ seps = [self.sep, self.sep2]
55
+ ret = self.system + seps[0]
56
+ for i, (role, message) in enumerate(messages):
57
+ if message:
58
+ if type(message) is tuple:
59
+ message, _, _ = message
60
+ ret += role + ": " + message + seps[i % 2]
61
+ else:
62
+ ret += role + ":"
63
+ elif self.sep_style == SeparatorStyle.MPT:
64
+ ret = self.system + self.sep
65
+ for role, message in messages:
66
+ if message:
67
+ if type(message) is tuple:
68
+ message, _, _ = message
69
+ ret += role + message + self.sep
70
+ else:
71
+ ret += role
72
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
73
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
74
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
75
+ ret = ""
76
+
77
+ for i, (role, message) in enumerate(messages):
78
+ if i == 0:
79
+ assert message, "first message should not be none"
80
+ assert role == self.roles[0], "first message should come from user"
81
+ if message:
82
+ if type(message) is tuple:
83
+ message, _, _ = message
84
+ if i == 0:
85
+ message = wrap_sys(self.system) + message
86
+ if i % 2 == 0:
87
+ message = wrap_inst(message)
88
+ ret += self.sep + message
89
+ else:
90
+ ret += " " + message + " " + self.sep2
91
+ else:
92
+ ret += ""
93
+ ret = ret.lstrip(self.sep)
94
+ elif self.sep_style == SeparatorStyle.PLAIN:
95
+ seps = [self.sep, self.sep2]
96
+ ret = self.system
97
+ for i, (role, message) in enumerate(messages):
98
+ if message:
99
+ if type(message) is tuple:
100
+ message, _, _ = message
101
+ ret += message + seps[i % 2]
102
+ else:
103
+ ret += ""
104
+ else:
105
+ raise ValueError(f"Invalid style: {self.sep_style}")
106
+
107
+ return ret
108
+
109
+ def append_message(self, role, message):
110
+ self.messages.append([role, message])
111
+
112
+ def get_images(self, return_pil=False):
113
+ images = []
114
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
115
+ if i % 2 == 0:
116
+ if type(msg) is tuple:
117
+ import base64
118
+ from io import BytesIO
119
+
120
+ from PIL import Image
121
+
122
+ msg, image, image_process_mode = msg
123
+ if image_process_mode == "Pad":
124
+
125
+ def expand2square(pil_img, background_color=(122, 116, 104)):
126
+ width, height = pil_img.size
127
+ if width == height:
128
+ return pil_img
129
+ elif width > height:
130
+ result = Image.new(
131
+ pil_img.mode, (width, width), background_color
132
+ )
133
+ result.paste(pil_img, (0, (width - height) // 2))
134
+ return result
135
+ else:
136
+ result = Image.new(
137
+ pil_img.mode, (height, height), background_color
138
+ )
139
+ result.paste(pil_img, ((height - width) // 2, 0))
140
+ return result
141
+
142
+ image = expand2square(image)
143
+ elif image_process_mode == "Crop":
144
+ pass
145
+ elif image_process_mode == "Resize":
146
+ image = image.resize((336, 336))
147
+ else:
148
+ raise ValueError(
149
+ f"Invalid image_process_mode: {image_process_mode}"
150
+ )
151
+ max_hw, min_hw = max(image.size), min(image.size)
152
+ aspect_ratio = max_hw / min_hw
153
+ max_len, min_len = 800, 400
154
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
155
+ longest_edge = int(shortest_edge * aspect_ratio)
156
+ W, H = image.size
157
+ if H > W:
158
+ H, W = longest_edge, shortest_edge
159
+ else:
160
+ H, W = shortest_edge, longest_edge
161
+ image = image.resize((W, H))
162
+ if return_pil:
163
+ images.append(image)
164
+ else:
165
+ buffered = BytesIO()
166
+ image.save(buffered, format="PNG")
167
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
168
+ images.append(img_b64_str)
169
+ return images
170
+
171
+ def to_gradio_chatbot(self):
172
+ ret = []
173
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
174
+ if i % 2 == 0:
175
+ if type(msg) is tuple:
176
+ import base64
177
+ from io import BytesIO
178
+
179
+ msg, image, image_process_mode = msg
180
+ max_hw, min_hw = max(image.size), min(image.size)
181
+ aspect_ratio = max_hw / min_hw
182
+ max_len, min_len = 800, 400
183
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
184
+ longest_edge = int(shortest_edge * aspect_ratio)
185
+ W, H = image.size
186
+ if H > W:
187
+ H, W = longest_edge, shortest_edge
188
+ else:
189
+ H, W = shortest_edge, longest_edge
190
+ image = image.resize((W, H))
191
+ buffered = BytesIO()
192
+ image.save(buffered, format="JPEG")
193
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
194
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
195
+ ret.append([img_str, None])
196
+ msg = msg.replace("<image>", "").strip()
197
+ if len(msg) > 0:
198
+ ret.append([msg, None])
199
+ else:
200
+ ret.append([msg, None])
201
+ else:
202
+ ret[-1][-1] = msg
203
+ return ret
204
+
205
+ def copy(self):
206
+ return Conversation(
207
+ system=self.system,
208
+ roles=self.roles,
209
+ messages=[[x, y] for x, y in self.messages],
210
+ offset=self.offset,
211
+ sep_style=self.sep_style,
212
+ sep=self.sep,
213
+ sep2=self.sep2,
214
+ version=self.version,
215
+ )
216
+
217
+ def dict(self):
218
+ if len(self.get_images()) > 0:
219
+ return {
220
+ "system": self.system,
221
+ "roles": self.roles,
222
+ "messages": [
223
+ [x, y[0] if type(y) is tuple else y] for x, y in self.messages
224
+ ],
225
+ "offset": self.offset,
226
+ "sep": self.sep,
227
+ "sep2": self.sep2,
228
+ }
229
+ return {
230
+ "system": self.system,
231
+ "roles": self.roles,
232
+ "messages": self.messages,
233
+ "offset": self.offset,
234
+ "sep": self.sep,
235
+ "sep2": self.sep2,
236
+ }
237
+
238
+
239
+ conv_vicuna_v0 = Conversation(
240
+ system="A chat between a curious human and an artificial intelligence assistant. "
241
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
242
+ roles=("Human", "Assistant"),
243
+ messages=(
244
+ (
245
+ "Human",
246
+ "What are the key differences between renewable and non-renewable energy sources?",
247
+ ),
248
+ (
249
+ "Assistant",
250
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
251
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
252
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
253
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
254
+ "renewable and non-renewable energy sources:\n"
255
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
256
+ "energy sources are finite and will eventually run out.\n"
257
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
258
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
259
+ "and other negative effects.\n"
260
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
261
+ "have lower operational costs than non-renewable sources.\n"
262
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
263
+ "locations than non-renewable sources.\n"
264
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
265
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
266
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
267
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
268
+ ),
269
+ ),
270
+ offset=2,
271
+ sep_style=SeparatorStyle.SINGLE,
272
+ sep="###",
273
+ )
274
+
275
+ conv_vicuna_v1 = Conversation(
276
+ system="A chat between a curious user and an artificial intelligence assistant. "
277
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
278
+ roles=("USER", "ASSISTANT"),
279
+ version="v1",
280
+ messages=(),
281
+ offset=0,
282
+ sep_style=SeparatorStyle.TWO,
283
+ sep=" ",
284
+ sep2="</s>",
285
+ )
286
+
287
+ conv_llama_2 = Conversation(
288
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
289
+
290
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
291
+ roles=("USER", "ASSISTANT"),
292
+ version="llama_v2",
293
+ messages=(),
294
+ offset=0,
295
+ sep_style=SeparatorStyle.LLAMA_2,
296
+ sep="<s>",
297
+ sep2="</s>",
298
+ )
299
+
300
+ conv_llava_llama_2 = Conversation(
301
+ system="You are a helpful language and vision assistant. "
302
+ "You are able to understand the visual content that the user provides, "
303
+ "and assist the user with a variety of tasks using natural language.",
304
+ roles=("USER", "ASSISTANT"),
305
+ version="llama_v2",
306
+ messages=(),
307
+ offset=0,
308
+ sep_style=SeparatorStyle.LLAMA_2,
309
+ sep="<s>",
310
+ sep2="</s>",
311
+ )
312
+
313
+ conv_mpt = Conversation(
314
+ system="""<|im_start|>system
315
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
316
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
317
+ version="mpt",
318
+ messages=(),
319
+ offset=0,
320
+ sep_style=SeparatorStyle.MPT,
321
+ sep="<|im_end|>",
322
+ )
323
+
324
+ conv_llava_plain = Conversation(
325
+ system="",
326
+ roles=("", ""),
327
+ messages=(),
328
+ offset=0,
329
+ sep_style=SeparatorStyle.PLAIN,
330
+ sep="\n",
331
+ )
332
+
333
+ conv_llava_v0 = Conversation(
334
+ system="A chat between a curious human and an artificial intelligence assistant. "
335
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
336
+ roles=("Human", "Assistant"),
337
+ messages=(("Human", "Hi!"), ("Assistant", "Hi there! How can I help you today?")),
338
+ offset=2,
339
+ sep_style=SeparatorStyle.SINGLE,
340
+ sep="###",
341
+ )
342
+
343
+ conv_llava_v0_mmtag = Conversation(
344
+ system="A chat between a curious user and an artificial intelligence assistant. "
345
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
346
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
347
+ roles=("Human", "Assistant"),
348
+ messages=(),
349
+ offset=0,
350
+ sep_style=SeparatorStyle.SINGLE,
351
+ sep="###",
352
+ version="v0_mmtag",
353
+ )
354
+
355
+ conv_llava_v1 = Conversation(
356
+ system="A chat between a curious human and an artificial intelligence assistant. "
357
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
358
+ roles=("USER", "ASSISTANT"),
359
+ version="v1",
360
+ messages=(),
361
+ offset=0,
362
+ sep_style=SeparatorStyle.TWO,
363
+ sep=" ",
364
+ sep2="</s>",
365
+ )
366
+
367
+ conv_llava_v1_mmtag = Conversation(
368
+ system="A chat between a curious user and an artificial intelligence assistant. "
369
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
370
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
371
+ roles=("USER", "ASSISTANT"),
372
+ messages=(),
373
+ offset=0,
374
+ sep_style=SeparatorStyle.TWO,
375
+ sep=" ",
376
+ sep2="</s>",
377
+ version="v1_mmtag",
378
+ )
379
+
380
+ default_conversation = conv_vicuna_v0
381
+ conv_templates = {
382
+ "default": conv_vicuna_v0,
383
+ "v0": conv_vicuna_v0,
384
+ "v1": conv_vicuna_v1,
385
+ "vicuna_v1": conv_vicuna_v1,
386
+ "llama_2": conv_llama_2,
387
+ "plain": conv_llava_plain,
388
+ "v0_plain": conv_llava_plain,
389
+ "llava_v0": conv_llava_v0,
390
+ "v0_mmtag": conv_llava_v0_mmtag,
391
+ "llava_v1": conv_llava_v1,
392
+ "v1_mmtag": conv_llava_v1_mmtag,
393
+ "llava_llama_2": conv_llava_llama_2,
394
+ "mpt": conv_mpt,
395
+ }
396
+
397
+
398
+ if __name__ == "__main__":
399
+ print(default_conversation.get_prompt())
lisa_on_cuda/llava/mm_utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from io import BytesIO
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from transformers import StoppingCriteria
7
+
8
+ from .constants import IMAGE_TOKEN_INDEX
9
+
10
+
11
+ def load_image_from_base64(image):
12
+ return Image.open(BytesIO(base64.b64decode(image)))
13
+
14
+
15
+ def process_images(images, image_processor, model_cfg):
16
+ return image_processor(images, return_tensors="pt")["pixel_values"]
17
+
18
+
19
+ def tokenizer_image_token(
20
+ prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
21
+ ):
22
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
23
+
24
+ def insert_separator(X, sep):
25
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
26
+
27
+ input_ids = []
28
+ offset = 0
29
+ if (
30
+ len(prompt_chunks) > 0
31
+ and len(prompt_chunks[0]) > 0
32
+ and prompt_chunks[0][0] == tokenizer.bos_token_id
33
+ ):
34
+ offset = 1
35
+ input_ids.append(prompt_chunks[0][0])
36
+
37
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
38
+ input_ids.extend(x[offset:])
39
+
40
+ if return_tensors is not None:
41
+ if return_tensors == "pt":
42
+ return torch.tensor(input_ids, dtype=torch.long)
43
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
44
+ return input_ids
45
+
46
+
47
+ def get_model_name_from_path(model_path):
48
+ model_path = model_path.strip("/")
49
+ model_paths = model_path.split("/")
50
+ if model_paths[-1].startswith("checkpoint-"):
51
+ return model_paths[-2] + "_" + model_paths[-1]
52
+ else:
53
+ return model_paths[-1]
54
+
55
+
56
+ class KeywordsStoppingCriteria(StoppingCriteria):
57
+ def __init__(self, keywords, tokenizer, input_ids):
58
+ self.keywords = keywords
59
+ self.keyword_ids = []
60
+ for keyword in keywords:
61
+ cur_keyword_ids = tokenizer(keyword).input_ids
62
+ if (
63
+ len(cur_keyword_ids) > 1
64
+ and cur_keyword_ids[0] == tokenizer.bos_token_id
65
+ ):
66
+ cur_keyword_ids = cur_keyword_ids[1:]
67
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
68
+ self.tokenizer = tokenizer
69
+ self.start_len = input_ids.shape[1]
70
+
71
+ def __call__(
72
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
73
+ ) -> bool:
74
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
75
+ offset = min(output_ids.shape[1] - self.start_len, 3)
76
+ self.keyword_ids = [
77
+ keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
78
+ ]
79
+ for keyword_id in self.keyword_ids:
80
+ if output_ids[0, -keyword_id.shape[0] :] == keyword_id:
81
+ return True
82
+ outputs = self.tokenizer.batch_decode(
83
+ output_ids[:, -offset:], skip_special_tokens=True
84
+ )[0]
85
+ for keyword in self.keywords:
86
+ if keyword in outputs:
87
+ return True
88
+ return False
lisa_on_cuda/llava/model/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM
2
+ from .language_model.llava_mpt import LlavaMPTConfig, LlavaMPTForCausalLM
lisa_on_cuda/llava/model/apply_delta.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+
11
+ from .language_model.llava_llama import LlavaLlamaForCausalLM
12
+
13
+
14
+ def apply_delta(base_model_path, target_model_path, delta_path):
15
+ print("Loading base model")
16
+ base = AutoModelForCausalLM.from_pretrained(
17
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
18
+ )
19
+
20
+ print("Loading delta")
21
+ delta = LlavaLlamaForCausalLM.from_pretrained(
22
+ delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
23
+ )
24
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
25
+
26
+ print("Applying delta")
27
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
28
+ if name not in base.state_dict():
29
+ assert name in [
30
+ "model.mm_projector.weight",
31
+ "model.mm_projector.bias",
32
+ ], f"{name} not in base model"
33
+ continue
34
+ if param.data.shape == base.state_dict()[name].shape:
35
+ param.data += base.state_dict()[name]
36
+ else:
37
+ assert name in [
38
+ "model.embed_tokens.weight",
39
+ "lm_head.weight",
40
+ ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
41
+ bparam = base.state_dict()[name]
42
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
43
+
44
+ print("Saving target model")
45
+ delta.save_pretrained(target_model_path)
46
+ delta_tokenizer.save_pretrained(target_model_path)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ parser = argparse.ArgumentParser()
51
+ parser.add_argument("--base-model-path", type=str, required=True)
52
+ parser.add_argument("--target-model-path", type=str, required=True)
53
+ parser.add_argument("--delta-path", type=str, required=True)
54
+
55
+ args = parser.parse_args()
56
+
57
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
lisa_on_cuda/llava/model/builder.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import shutil
18
+
19
+ import torch
20
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
21
+
22
+ from .language_model.llava_llama import LlavaLlamaForCausalLM
23
+ from .language_model.llava_mpt import LlavaMPTForCausalLM
24
+ from ..constants import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
25
+
26
+
27
+ def load_pretrained_model(
28
+ model_path,
29
+ model_base,
30
+ model_name,
31
+ load_8bit=False,
32
+ load_4bit=False,
33
+ device_map="auto",
34
+ ):
35
+ kwargs = {"device_map": device_map}
36
+
37
+ if load_8bit:
38
+ kwargs["load_in_8bit"] = True
39
+ elif load_4bit:
40
+ kwargs["load_in_4bit"] = True
41
+ kwargs["quantization_config"] = BitsAndBytesConfig(
42
+ load_in_4bit=True,
43
+ bnb_4bit_compute_dtype=torch.float16,
44
+ bnb_4bit_use_double_quant=True,
45
+ bnb_4bit_quant_type="nf4",
46
+ )
47
+ else:
48
+ kwargs["torch_dtype"] = torch.float16
49
+
50
+ if "llava" in model_name.lower():
51
+ # Load LLaVA model
52
+ if "lora" in model_name.lower() and model_base is not None:
53
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
54
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
55
+ print("Loading LLaVA from base model...")
56
+ model = LlavaLlamaForCausalLM.from_pretrained(
57
+ model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs
58
+ )
59
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
60
+ if model.lm_head.weight.shape[0] != token_num:
61
+ model.lm_head.weight = torch.nn.Parameter(
62
+ torch.empty(
63
+ token_num, tokem_dim, device=model.device, dtype=model.dtype
64
+ )
65
+ )
66
+ model.model.embed_tokens.weight = torch.nn.Parameter(
67
+ torch.empty(
68
+ token_num, tokem_dim, device=model.device, dtype=model.dtype
69
+ )
70
+ )
71
+
72
+ print("Loading additional LLaVA weights...")
73
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
74
+ non_lora_trainables = torch.load(
75
+ os.path.join(model_path, "non_lora_trainables.bin"),
76
+ map_location="cpu",
77
+ )
78
+ else:
79
+ # this is probably from HF Hub
80
+ from huggingface_hub import hf_hub_download
81
+
82
+ def load_from_hf(repo_id, filename, subfolder=None):
83
+ cache_file = hf_hub_download(
84
+ repo_id=repo_id, filename=filename, subfolder=subfolder
85
+ )
86
+ return torch.load(cache_file, map_location="cpu")
87
+
88
+ non_lora_trainables = load_from_hf(
89
+ model_path, "non_lora_trainables.bin"
90
+ )
91
+ non_lora_trainables = {
92
+ (k[11:] if k.startswith("base_model.") else k): v
93
+ for k, v in non_lora_trainables.items()
94
+ }
95
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
96
+ non_lora_trainables = {
97
+ (k[6:] if k.startswith("model.") else k): v
98
+ for k, v in non_lora_trainables.items()
99
+ }
100
+ model.load_state_dict(non_lora_trainables, strict=False)
101
+
102
+ from peft import PeftModel
103
+
104
+ print("Loading LoRA weights...")
105
+ model = PeftModel.from_pretrained(model, model_path)
106
+ print("Merging LoRA weights...")
107
+ model = model.merge_and_unload()
108
+ print("Model is loaded...")
109
+ elif model_base is not None:
110
+ # this may be mm projector only
111
+ print("Loading LLaVA from base model...")
112
+ if "mpt" in model_name.lower():
113
+ if not os.path.isfile(os.path.join(model_path, "configuration_mpt.py")):
114
+ shutil.copyfile(
115
+ os.path.join(model_base, "configuration_mpt.py"),
116
+ os.path.join(model_path, "configuration_mpt.py"),
117
+ )
118
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
119
+ cfg_pretrained = AutoConfig.from_pretrained(
120
+ model_path, trust_remote_code=True
121
+ )
122
+ model = LlavaMPTForCausalLM.from_pretrained(
123
+ model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
124
+ )
125
+ else:
126
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
127
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
128
+ model = LlavaLlamaForCausalLM.from_pretrained(
129
+ model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
130
+ )
131
+
132
+ mm_projector_weights = torch.load(
133
+ os.path.join(model_path, "mm_projector.bin"), map_location="cpu"
134
+ )
135
+ mm_projector_weights = {
136
+ k: v.to(torch.float16) for k, v in mm_projector_weights.items()
137
+ }
138
+ model.load_state_dict(mm_projector_weights, strict=False)
139
+ else:
140
+ if "mpt" in model_name.lower():
141
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
142
+ model = LlavaMPTForCausalLM.from_pretrained(
143
+ model_path, low_cpu_mem_usage=True, **kwargs
144
+ )
145
+ else:
146
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
147
+ model = LlavaLlamaForCausalLM.from_pretrained(
148
+ model_path, low_cpu_mem_usage=True, **kwargs
149
+ )
150
+ else:
151
+ # Load language model
152
+ if model_base is not None:
153
+ # PEFT model
154
+ from peft import PeftModel
155
+
156
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
157
+ model = AutoModelForCausalLM.from_pretrained(
158
+ model_base,
159
+ torch_dtype=torch.float16,
160
+ low_cpu_mem_usage=True,
161
+ device_map="auto",
162
+ )
163
+ print(f"Loading LoRA weights from {model_path}")
164
+ model = PeftModel.from_pretrained(model, model_path)
165
+ print(f"Merging weights")
166
+ model = model.merge_and_unload()
167
+ print("Convert to FP16...")
168
+ model.to(torch.float16)
169
+ else:
170
+ use_fast = False
171
+ if "mpt" in model_name.lower():
172
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
173
+ model = AutoModelForCausalLM.from_pretrained(
174
+ model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs
175
+ )
176
+ else:
177
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
178
+ model = AutoModelForCausalLM.from_pretrained(
179
+ model_path, low_cpu_mem_usage=True, **kwargs
180
+ )
181
+
182
+ image_processor = None
183
+
184
+ if "llava" in model_name.lower():
185
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
186
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
187
+ if mm_use_im_patch_token:
188
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
189
+ if mm_use_im_start_end:
190
+ tokenizer.add_tokens(
191
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
192
+ )
193
+ model.resize_token_embeddings(len(tokenizer))
194
+
195
+ vision_tower = model.get_vision_tower()
196
+ if not vision_tower.is_loaded:
197
+ vision_tower.load_model()
198
+ vision_tower.to(device="cuda", dtype=torch.float16)
199
+ image_processor = vision_tower.image_processor
200
+
201
+ if hasattr(model.config, "max_sequence_length"):
202
+ context_len = model.config.max_sequence_length
203
+ else:
204
+ context_len = 2048
205
+
206
+ return tokenizer, model, image_processor, context_len
lisa_on_cuda/llava/model/consolidate.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from .utils import auto_upgrade
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+
11
+
12
+ def consolidate_ckpt(src_path, dst_path):
13
+ print("Loading model")
14
+ auto_upgrade(src_path)
15
+ src_model = AutoModelForCausalLM.from_pretrained(
16
+ src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
17
+ )
18
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
19
+ src_model.save_pretrained(dst_path)
20
+ src_tokenizer.save_pretrained(dst_path)
21
+
22
+
23
+ if __name__ == "__main__":
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--src", type=str, required=True)
26
+ parser.add_argument("--dst", type=str, required=True)
27
+
28
+ args = parser.parse_args()
29
+
30
+ consolidate_ckpt(args.src, args.dst)
lisa_on_cuda/llava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+ from transformers import (AutoConfig, AutoModelForCausalLM, LlamaConfig,
22
+ LlamaForCausalLM, LlamaModel)
23
+ from transformers.modeling_outputs import CausalLMOutputWithPast
24
+
25
+ from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
26
+
27
+
28
+ class LlavaConfig(LlamaConfig):
29
+ model_type = "llava"
30
+
31
+
32
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
33
+ config_class = LlavaConfig
34
+
35
+ def __init__(self, config: LlamaConfig):
36
+ super(LlavaLlamaModel, self).__init__(config)
37
+
38
+
39
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
40
+ config_class = LlavaConfig
41
+
42
+ def __init__(self, config):
43
+ super(LlamaForCausalLM, self).__init__(config)
44
+
45
+ self.model = LlavaLlamaModel(config)
46
+
47
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
48
+
49
+ # Initialize weights and apply final processing
50
+ self.post_init()
51
+
52
+ def get_model(self):
53
+ return self.model
54
+
55
+ def forward(
56
+ self,
57
+ input_ids: torch.LongTensor = None,
58
+ attention_mask: Optional[torch.Tensor] = None,
59
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
60
+ inputs_embeds: Optional[torch.FloatTensor] = None,
61
+ labels: Optional[torch.LongTensor] = None,
62
+ use_cache: Optional[bool] = None,
63
+ output_attentions: Optional[bool] = None,
64
+ output_hidden_states: Optional[bool] = None,
65
+ images: Optional[torch.FloatTensor] = None,
66
+ return_dict: Optional[bool] = None,
67
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
68
+ output_attentions = (
69
+ output_attentions
70
+ if output_attentions is not None
71
+ else self.config.output_attentions
72
+ )
73
+ output_hidden_states = (
74
+ output_hidden_states
75
+ if output_hidden_states is not None
76
+ else self.config.output_hidden_states
77
+ )
78
+ return_dict = (
79
+ return_dict if return_dict is not None else self.config.use_return_dict
80
+ )
81
+
82
+ (
83
+ input_ids,
84
+ attention_mask,
85
+ past_key_values,
86
+ inputs_embeds,
87
+ labels,
88
+ ) = self.prepare_inputs_labels_for_multimodal(
89
+ input_ids, attention_mask, past_key_values, labels, images
90
+ )
91
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
92
+
93
+ outputs = self.model(
94
+ input_ids=input_ids,
95
+ attention_mask=attention_mask,
96
+ past_key_values=past_key_values,
97
+ inputs_embeds=inputs_embeds,
98
+ use_cache=use_cache,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=return_dict,
102
+ )
103
+
104
+ hidden_states = outputs[0]
105
+ logits = self.lm_head(hidden_states)
106
+
107
+ loss = None
108
+ if labels is not None:
109
+ # Shift so that tokens < n predict n
110
+ shift_logits = logits[..., :-1, :].contiguous()
111
+ shift_labels = labels[..., 1:].contiguous()
112
+ # Flatten the tokens
113
+ loss_fct = CrossEntropyLoss()
114
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
115
+ shift_labels = shift_labels.view(-1)
116
+ # Enable model/pipeline parallelism
117
+ shift_labels = shift_labels.to(shift_logits.device)
118
+ loss = loss_fct(shift_logits, shift_labels)
119
+
120
+ if not return_dict:
121
+ output = (logits,) + outputs[1:]
122
+ return (loss,) + output if loss is not None else output
123
+
124
+ if self.training:
125
+ output_hidden_states = outputs.hidden_states
126
+ else:
127
+ output_hidden_states = hidden_states
128
+
129
+ return CausalLMOutputWithPast(
130
+ loss=loss,
131
+ logits=logits,
132
+ past_key_values=outputs.past_key_values,
133
+ hidden_states=output_hidden_states, # outputs.hidden_states,
134
+ attentions=outputs.attentions,
135
+ )
136
+
137
+ def prepare_inputs_for_generation(
138
+ self,
139
+ input_ids,
140
+ past_key_values=None,
141
+ attention_mask=None,
142
+ inputs_embeds=None,
143
+ images=None,
144
+ **kwargs
145
+ ):
146
+ if past_key_values:
147
+ input_ids = input_ids[:, -1:]
148
+
149
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
150
+ if inputs_embeds is not None and past_key_values is None:
151
+ model_inputs = {"inputs_embeds": inputs_embeds}
152
+ else:
153
+ model_inputs = {"input_ids": input_ids}
154
+
155
+ model_inputs.update(
156
+ {
157
+ "past_key_values": past_key_values,
158
+ "use_cache": kwargs.get("use_cache"),
159
+ "attention_mask": attention_mask,
160
+ "images": images,
161
+ }
162
+ )
163
+ return model_inputs
164
+
165
+
166
+ AutoConfig.register("llava", LlavaConfig)
167
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
lisa_on_cuda/llava/model/language_model/llava_mpt.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import math
17
+ import warnings
18
+ from typing import List, Optional, Tuple
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from transformers import AutoConfig, AutoModelForCausalLM
23
+ from transformers.modeling_outputs import CausalLMOutputWithPast
24
+
25
+ from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
26
+ from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
27
+
28
+
29
+ class LlavaMPTConfig(MPTConfig):
30
+ model_type = "llava_mpt"
31
+
32
+
33
+ class LlavaMPTModel(LlavaMetaModel, MPTModel):
34
+ config_class = LlavaMPTConfig
35
+
36
+ def __init__(self, config: MPTConfig):
37
+ config.hidden_size = config.d_model
38
+ super(LlavaMPTModel, self).__init__(config)
39
+
40
+ def embed_tokens(self, x):
41
+ return self.wte(x)
42
+
43
+
44
+ class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM):
45
+ config_class = LlavaMPTConfig
46
+ supports_gradient_checkpointing = True
47
+
48
+ def __init__(self, config):
49
+ super(MPTForCausalLM, self).__init__(config)
50
+
51
+ if not config.tie_word_embeddings:
52
+ raise ValueError("MPTForCausalLM only supports tied word embeddings")
53
+ self.transformer = LlavaMPTModel(config)
54
+ self.logit_scale = None
55
+ if config.logit_scale is not None:
56
+ logit_scale = config.logit_scale
57
+ if isinstance(logit_scale, str):
58
+ if logit_scale == "inv_sqrt_d_model":
59
+ logit_scale = 1 / math.sqrt(config.d_model)
60
+ else:
61
+ raise ValueError(
62
+ f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
63
+ )
64
+ self.logit_scale = logit_scale
65
+
66
+ def get_model(self):
67
+ return self.transformer
68
+
69
+ def _set_gradient_checkpointing(self, module, value=False):
70
+ if isinstance(module, LlavaMPTModel):
71
+ module.gradient_checkpointing = value
72
+
73
+ def forward(
74
+ self,
75
+ input_ids: torch.LongTensor,
76
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
77
+ attention_mask: Optional[torch.ByteTensor] = None,
78
+ prefix_mask: Optional[torch.ByteTensor] = None,
79
+ sequence_id: Optional[torch.LongTensor] = None,
80
+ labels: Optional[torch.LongTensor] = None,
81
+ return_dict: Optional[bool] = None,
82
+ output_attentions: Optional[bool] = None,
83
+ output_hidden_states: Optional[bool] = None,
84
+ use_cache: Optional[bool] = None,
85
+ images=None,
86
+ ):
87
+ return_dict = (
88
+ return_dict if return_dict is not None else self.config.return_dict
89
+ )
90
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
91
+
92
+ (
93
+ input_ids,
94
+ attention_mask,
95
+ past_key_values,
96
+ inputs_embeds,
97
+ labels,
98
+ ) = self.prepare_inputs_labels_for_multimodal(
99
+ input_ids, attention_mask, past_key_values, labels, images
100
+ )
101
+ outputs = self.transformer(
102
+ input_ids=input_ids,
103
+ inputs_embeds=inputs_embeds,
104
+ past_key_values=past_key_values,
105
+ attention_mask=attention_mask,
106
+ prefix_mask=prefix_mask,
107
+ sequence_id=sequence_id,
108
+ return_dict=return_dict,
109
+ output_attentions=output_attentions,
110
+ output_hidden_states=output_hidden_states,
111
+ use_cache=use_cache,
112
+ )
113
+ # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
114
+ logits = F.linear(
115
+ outputs.last_hidden_state.to(self.transformer.wte.weight.device),
116
+ self.transformer.wte.weight,
117
+ )
118
+ if self.logit_scale is not None:
119
+ if self.logit_scale == 0:
120
+ warnings.warn(
121
+ f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
122
+ )
123
+ logits *= self.logit_scale
124
+ loss = None
125
+ if labels is not None:
126
+ labels = torch.roll(labels, shifts=-1)
127
+ labels[:, -1] = -100
128
+ loss = F.cross_entropy(
129
+ logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
130
+ )
131
+ return CausalLMOutputWithPast(
132
+ loss=loss,
133
+ logits=logits,
134
+ past_key_values=outputs.past_key_values,
135
+ hidden_states=outputs.hidden_states,
136
+ )
137
+
138
+ def prepare_inputs_for_generation(
139
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
140
+ ):
141
+ if inputs_embeds is not None:
142
+ raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
143
+ attention_mask = kwargs["attention_mask"].bool()
144
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
145
+ raise NotImplementedError(
146
+ "MPT does not support generation with right padding."
147
+ )
148
+ if self.transformer.attn_uses_sequence_id and self.training:
149
+ sequence_id = torch.zeros_like(input_ids[:1])
150
+ else:
151
+ sequence_id = None
152
+ if past_key_values is not None:
153
+ input_ids = input_ids[:, -1].unsqueeze(-1)
154
+ if self.transformer.prefix_lm:
155
+ prefix_mask = torch.ones_like(attention_mask)
156
+ if kwargs.get("use_cache") == False:
157
+ raise NotImplementedError(
158
+ "MPT with prefix_lm=True does not support use_cache=False."
159
+ )
160
+ else:
161
+ prefix_mask = None
162
+ return {
163
+ "input_ids": input_ids,
164
+ "attention_mask": attention_mask,
165
+ "prefix_mask": prefix_mask,
166
+ "sequence_id": sequence_id,
167
+ "past_key_values": past_key_values,
168
+ "use_cache": kwargs.get("use_cache", True),
169
+ "images": kwargs.get("images", None),
170
+ }
171
+
172
+
173
+ AutoConfig.register("llava_mpt", LlavaMPTConfig)
174
+ AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
lisa_on_cuda/llava/model/language_model/mpt/adapt_tokenizer.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ from transformers import (AutoTokenizer, PreTrainedTokenizer,
4
+ PreTrainedTokenizerFast)
5
+
6
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
7
+ NUM_SENTINEL_TOKENS: int = 100
8
+
9
+
10
+ def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
11
+ """Adds sentinel tokens and padding token (if missing).
12
+
13
+ Expands the tokenizer vocabulary to include sentinel tokens
14
+ used in mixture-of-denoiser tasks as well as a padding token.
15
+
16
+ All added tokens are added as special tokens. No tokens are
17
+ added if sentinel tokens and padding token already exist.
18
+ """
19
+ sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
20
+ tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
21
+ if tokenizer.pad_token is None:
22
+ tokenizer.add_tokens("<pad>", special_tokens=True)
23
+ tokenizer.pad_token = "<pad>"
24
+ assert tokenizer.pad_token_id is not None
25
+ sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
26
+ _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
27
+ tokenizer.sentinel_token_ids = _sentinel_token_ids
28
+
29
+
30
+ class AutoTokenizerForMOD(AutoTokenizer):
31
+ """AutoTokenizer + Adaptation for MOD.
32
+
33
+ A simple wrapper around AutoTokenizer to make instantiating
34
+ an MOD-adapted tokenizer a bit easier.
35
+
36
+ MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
37
+ a padding token, and a property to get the token ids of the
38
+ sentinel tokens.
39
+ """
40
+
41
+ @classmethod
42
+ def from_pretrained(cls, *args, **kwargs):
43
+ """See `AutoTokenizer.from_pretrained` docstring."""
44
+ tokenizer = super().from_pretrained(*args, **kwargs)
45
+ adapt_tokenizer_for_denoising(tokenizer)
46
+ return tokenizer
lisa_on_cuda/llava/model/language_model/mpt/attention.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Attention layers."""
2
+ import math
3
+ import warnings
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from packaging import version
10
+ from torch import nn
11
+
12
+ from .norm import LPLayerNorm
13
+
14
+
15
+ def _reset_is_causal(
16
+ num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
17
+ ):
18
+ if original_is_causal and num_query_tokens != num_key_tokens:
19
+ if num_query_tokens != 1:
20
+ raise NotImplementedError(
21
+ "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
22
+ )
23
+ else:
24
+ return False
25
+ return original_is_causal
26
+
27
+
28
+ def scaled_multihead_dot_product_attention(
29
+ query,
30
+ key,
31
+ value,
32
+ n_heads,
33
+ past_key_value=None,
34
+ softmax_scale=None,
35
+ attn_bias=None,
36
+ key_padding_mask=None,
37
+ is_causal=False,
38
+ dropout_p=0.0,
39
+ training=False,
40
+ needs_weights=False,
41
+ multiquery=False,
42
+ ):
43
+ q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
44
+ kv_n_heads = 1 if multiquery else n_heads
45
+ k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
46
+ v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
47
+ if past_key_value is not None:
48
+ if len(past_key_value) != 0:
49
+ k = torch.cat([past_key_value[0], k], dim=3)
50
+ v = torch.cat([past_key_value[1], v], dim=2)
51
+ past_key_value = (k, v)
52
+ (b, _, s_q, d) = q.shape
53
+ s_k = k.size(-1)
54
+ if softmax_scale is None:
55
+ softmax_scale = 1 / math.sqrt(d)
56
+ attn_weight = q.matmul(k) * softmax_scale
57
+ if attn_bias is not None:
58
+ _s_q = max(0, attn_bias.size(2) - s_q)
59
+ _s_k = max(0, attn_bias.size(3) - s_k)
60
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
61
+ if (
62
+ attn_bias.size(-1) != 1
63
+ and attn_bias.size(-1) != s_k
64
+ or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
65
+ ):
66
+ raise RuntimeError(
67
+ f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
68
+ )
69
+ attn_weight = attn_weight + attn_bias
70
+ min_val = torch.finfo(q.dtype).min
71
+ if key_padding_mask is not None:
72
+ if attn_bias is not None:
73
+ warnings.warn(
74
+ "Propogating key_padding_mask to the attention module "
75
+ + "and applying it within the attention module can cause "
76
+ + "unneccessary computation/memory usage. Consider integrating "
77
+ + "into attn_bias once and passing that to each attention "
78
+ + "module instead."
79
+ )
80
+ attn_weight = attn_weight.masked_fill(
81
+ ~key_padding_mask.view((b, 1, 1, s_k)), min_val
82
+ )
83
+ if is_causal and (not q.size(2) == 1):
84
+ s = max(s_q, s_k)
85
+ causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
86
+ causal_mask = causal_mask.tril()
87
+ causal_mask = causal_mask.to(torch.bool)
88
+ causal_mask = ~causal_mask
89
+ causal_mask = causal_mask[-s_q:, -s_k:]
90
+ attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
91
+ attn_weight = torch.softmax(attn_weight, dim=-1)
92
+ if dropout_p:
93
+ attn_weight = torch.nn.functional.dropout(
94
+ attn_weight, p=dropout_p, training=training, inplace=True
95
+ )
96
+ out = attn_weight.to(v.dtype).matmul(v)
97
+ out = rearrange(out, "b h s d -> b s (h d)")
98
+ if needs_weights:
99
+ return (out, attn_weight, past_key_value)
100
+ return (out, None, past_key_value)
101
+
102
+
103
+ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
104
+ for tensor in tensors:
105
+ if tensor.dtype not in valid_dtypes:
106
+ raise TypeError(
107
+ f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
108
+ )
109
+ if not tensor.is_cuda:
110
+ raise TypeError(
111
+ f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
112
+ )
113
+
114
+
115
+ def flash_attn_fn(
116
+ query,
117
+ key,
118
+ value,
119
+ n_heads,
120
+ past_key_value=None,
121
+ softmax_scale=None,
122
+ attn_bias=None,
123
+ key_padding_mask=None,
124
+ is_causal=False,
125
+ dropout_p=0.0,
126
+ training=False,
127
+ needs_weights=False,
128
+ multiquery=False,
129
+ ):
130
+ try:
131
+ from flash_attn import bert_padding, flash_attn_interface
132
+ except:
133
+ raise RuntimeError("Please install flash-attn==1.0.3.post0")
134
+ check_valid_inputs(query, key, value)
135
+ if past_key_value is not None:
136
+ if len(past_key_value) != 0:
137
+ key = torch.cat([past_key_value[0], key], dim=1)
138
+ value = torch.cat([past_key_value[1], value], dim=1)
139
+ past_key_value = (key, value)
140
+ if attn_bias is not None:
141
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
142
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
143
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
144
+ if attn_bias is not None:
145
+ raise NotImplementedError(f"attn_bias not implemented for flash attn.")
146
+ (batch_size, seqlen) = query.shape[:2]
147
+ if key_padding_mask is None:
148
+ key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
149
+ query_padding_mask = key_padding_mask[:, -query.size(1) :]
150
+ (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
151
+ query, query_padding_mask
152
+ )
153
+ query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
154
+ (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
155
+ key, key_padding_mask
156
+ )
157
+ key_unpad = rearrange(
158
+ key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
159
+ )
160
+ (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
161
+ value_unpad = rearrange(
162
+ value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
163
+ )
164
+ if multiquery:
165
+ key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
166
+ value_unpad = value_unpad.expand(
167
+ value_unpad.size(0), n_heads, value_unpad.size(-1)
168
+ )
169
+ dropout_p = dropout_p if training else 0.0
170
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
171
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
172
+ query_unpad,
173
+ key_unpad,
174
+ value_unpad,
175
+ cu_seqlens_q,
176
+ cu_seqlens_k,
177
+ max_seqlen_q,
178
+ max_seqlen_k,
179
+ dropout_p,
180
+ softmax_scale=softmax_scale,
181
+ causal=reset_is_causal,
182
+ return_attn_probs=needs_weights,
183
+ )
184
+ output = bert_padding.pad_input(
185
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
186
+ )
187
+ return (output, None, past_key_value)
188
+
189
+
190
+ def triton_flash_attn_fn(
191
+ query,
192
+ key,
193
+ value,
194
+ n_heads,
195
+ past_key_value=None,
196
+ softmax_scale=None,
197
+ attn_bias=None,
198
+ key_padding_mask=None,
199
+ is_causal=False,
200
+ dropout_p=0.0,
201
+ training=False,
202
+ needs_weights=False,
203
+ multiquery=False,
204
+ ):
205
+ try:
206
+ from .flash_attn_triton import flash_attn_func
207
+ except:
208
+ _installed = False
209
+ if version.parse(torch.__version__) < version.parse("2.0.0"):
210
+ _installed = True
211
+ try:
212
+ from flash_attn.flash_attn_triton import flash_attn_func
213
+ except:
214
+ _installed = False
215
+ if not _installed:
216
+ raise RuntimeError(
217
+ "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
218
+ )
219
+ check_valid_inputs(query, key, value)
220
+ if past_key_value is not None:
221
+ if len(past_key_value) != 0:
222
+ key = torch.cat([past_key_value[0], key], dim=1)
223
+ value = torch.cat([past_key_value[1], value], dim=1)
224
+ past_key_value = (key, value)
225
+ if attn_bias is not None:
226
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
227
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
228
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
229
+ if dropout_p:
230
+ raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
231
+ if needs_weights:
232
+ raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
233
+ if key_padding_mask is not None:
234
+ warnings.warn(
235
+ "Propagating key_padding_mask to the attention module "
236
+ + "and applying it within the attention module can cause "
237
+ + "unnecessary computation/memory usage. Consider integrating "
238
+ + "into attn_bias once and passing that to each attention "
239
+ + "module instead."
240
+ )
241
+ (b_size, s_k) = key_padding_mask.shape[:2]
242
+ if attn_bias is None:
243
+ attn_bias = query.new_zeros(b_size, 1, 1, s_k)
244
+ attn_bias = attn_bias.masked_fill(
245
+ ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
246
+ )
247
+ query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
248
+ key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
249
+ value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
250
+ if multiquery:
251
+ key = key.expand(*key.shape[:2], n_heads, key.size(-1))
252
+ value = value.expand(*value.shape[:2], n_heads, value.size(-1))
253
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
254
+ attn_output = flash_attn_func(
255
+ query, key, value, attn_bias, reset_is_causal, softmax_scale
256
+ )
257
+ output = attn_output.view(*attn_output.shape[:2], -1)
258
+ return (output, None, past_key_value)
259
+
260
+
261
+ class MultiheadAttention(nn.Module):
262
+ """Multi-head self attention.
263
+
264
+ Using torch or triton attention implemetation enables user to also use
265
+ additive bias.
266
+ """
267
+
268
+ def __init__(
269
+ self,
270
+ d_model: int,
271
+ n_heads: int,
272
+ attn_impl: str = "triton",
273
+ clip_qkv: Optional[float] = None,
274
+ qk_ln: bool = False,
275
+ softmax_scale: Optional[float] = None,
276
+ attn_pdrop: float = 0.0,
277
+ low_precision_layernorm: bool = False,
278
+ verbose: int = 0,
279
+ device: Optional[str] = None,
280
+ ):
281
+ super().__init__()
282
+ self.attn_impl = attn_impl
283
+ self.clip_qkv = clip_qkv
284
+ self.qk_ln = qk_ln
285
+ self.d_model = d_model
286
+ self.n_heads = n_heads
287
+ self.softmax_scale = softmax_scale
288
+ if self.softmax_scale is None:
289
+ self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
290
+ self.attn_dropout_p = attn_pdrop
291
+ self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
292
+ fuse_splits = (d_model, 2 * d_model)
293
+ self.Wqkv._fused = (0, fuse_splits)
294
+ if self.qk_ln:
295
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
296
+ self.q_ln = layernorm_class(self.d_model, device=device)
297
+ self.k_ln = layernorm_class(self.d_model, device=device)
298
+ if self.attn_impl == "flash":
299
+ self.attn_fn = flash_attn_fn
300
+ elif self.attn_impl == "triton":
301
+ self.attn_fn = triton_flash_attn_fn
302
+ if verbose:
303
+ warnings.warn(
304
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
305
+ + "it uses more memory. When training larger models this can trigger "
306
+ + "alloc retries which hurts performance. If encountered, we recommend "
307
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
308
+ )
309
+ elif self.attn_impl == "torch":
310
+ self.attn_fn = scaled_multihead_dot_product_attention
311
+ if torch.cuda.is_available() and verbose:
312
+ warnings.warn(
313
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
314
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
315
+ + "we recommend using `attn_impl: triton`."
316
+ )
317
+ else:
318
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
319
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
320
+ self.out_proj._is_residual = True
321
+
322
+ def forward(
323
+ self,
324
+ x,
325
+ past_key_value=None,
326
+ attn_bias=None,
327
+ attention_mask=None,
328
+ is_causal=True,
329
+ needs_weights=False,
330
+ ):
331
+ qkv = self.Wqkv(x)
332
+ if self.clip_qkv:
333
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
334
+ (query, key, value) = qkv.chunk(3, dim=2)
335
+ key_padding_mask = attention_mask
336
+ if self.qk_ln:
337
+ dtype = query.dtype
338
+ query = self.q_ln(query).to(dtype)
339
+ key = self.k_ln(key).to(dtype)
340
+ (context, attn_weights, past_key_value) = self.attn_fn(
341
+ query,
342
+ key,
343
+ value,
344
+ self.n_heads,
345
+ past_key_value=past_key_value,
346
+ softmax_scale=self.softmax_scale,
347
+ attn_bias=attn_bias,
348
+ key_padding_mask=key_padding_mask,
349
+ is_causal=is_causal,
350
+ dropout_p=self.attn_dropout_p,
351
+ training=self.training,
352
+ needs_weights=needs_weights,
353
+ )
354
+ return (self.out_proj(context), attn_weights, past_key_value)
355
+
356
+
357
+ class MultiQueryAttention(nn.Module):
358
+ """Multi-Query self attention.
359
+
360
+ Using torch or triton attention implemetation enables user to also use
361
+ additive bias.
362
+ """
363
+
364
+ def __init__(
365
+ self,
366
+ d_model: int,
367
+ n_heads: int,
368
+ attn_impl: str = "triton",
369
+ clip_qkv: Optional[float] = None,
370
+ qk_ln: bool = False,
371
+ softmax_scale: Optional[float] = None,
372
+ attn_pdrop: float = 0.0,
373
+ low_precision_layernorm: bool = False,
374
+ verbose: int = 0,
375
+ device: Optional[str] = None,
376
+ ):
377
+ super().__init__()
378
+ self.attn_impl = attn_impl
379
+ self.clip_qkv = clip_qkv
380
+ self.qk_ln = qk_ln
381
+ self.d_model = d_model
382
+ self.n_heads = n_heads
383
+ self.head_dim = d_model // n_heads
384
+ self.softmax_scale = softmax_scale
385
+ if self.softmax_scale is None:
386
+ self.softmax_scale = 1 / math.sqrt(self.head_dim)
387
+ self.attn_dropout_p = attn_pdrop
388
+ self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
389
+ fuse_splits = (d_model, d_model + self.head_dim)
390
+ self.Wqkv._fused = (0, fuse_splits)
391
+ if self.qk_ln:
392
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
393
+ self.q_ln = layernorm_class(d_model, device=device)
394
+ self.k_ln = layernorm_class(self.head_dim, device=device)
395
+ if self.attn_impl == "flash":
396
+ self.attn_fn = flash_attn_fn
397
+ elif self.attn_impl == "triton":
398
+ self.attn_fn = triton_flash_attn_fn
399
+ if verbose:
400
+ warnings.warn(
401
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
402
+ + "it uses more memory. When training larger models this can trigger "
403
+ + "alloc retries which hurts performance. If encountered, we recommend "
404
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
405
+ )
406
+ elif self.attn_impl == "torch":
407
+ self.attn_fn = scaled_multihead_dot_product_attention
408
+ if torch.cuda.is_available() and verbose:
409
+ warnings.warn(
410
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
411
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
412
+ + "we recommend using `attn_impl: triton`."
413
+ )
414
+ else:
415
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
416
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
417
+ self.out_proj._is_residual = True
418
+
419
+ def forward(
420
+ self,
421
+ x,
422
+ past_key_value=None,
423
+ attn_bias=None,
424
+ attention_mask=None,
425
+ is_causal=True,
426
+ needs_weights=False,
427
+ ):
428
+ qkv = self.Wqkv(x)
429
+ if self.clip_qkv:
430
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
431
+ (query, key, value) = qkv.split(
432
+ [self.d_model, self.head_dim, self.head_dim], dim=2
433
+ )
434
+ key_padding_mask = attention_mask
435
+ if self.qk_ln:
436
+ dtype = query.dtype
437
+ query = self.q_ln(query).to(dtype)
438
+ key = self.k_ln(key).to(dtype)
439
+ (context, attn_weights, past_key_value) = self.attn_fn(
440
+ query,
441
+ key,
442
+ value,
443
+ self.n_heads,
444
+ past_key_value=past_key_value,
445
+ softmax_scale=self.softmax_scale,
446
+ attn_bias=attn_bias,
447
+ key_padding_mask=key_padding_mask,
448
+ is_causal=is_causal,
449
+ dropout_p=self.attn_dropout_p,
450
+ training=self.training,
451
+ needs_weights=needs_weights,
452
+ multiquery=True,
453
+ )
454
+ return (self.out_proj(context), attn_weights, past_key_value)
455
+
456
+
457
+ def attn_bias_shape(
458
+ attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id
459
+ ):
460
+ if attn_impl == "flash":
461
+ return None
462
+ elif attn_impl in ["torch", "triton"]:
463
+ if alibi:
464
+ if (prefix_lm or not causal) or use_sequence_id:
465
+ return (1, n_heads, seq_len, seq_len)
466
+ return (1, n_heads, 1, seq_len)
467
+ elif prefix_lm or use_sequence_id:
468
+ return (1, 1, seq_len, seq_len)
469
+ return None
470
+ else:
471
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
472
+
473
+
474
+ def build_attn_bias(
475
+ attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8
476
+ ):
477
+ if attn_impl == "flash":
478
+ return None
479
+ elif attn_impl in ["torch", "triton"]:
480
+ if alibi:
481
+ (device, dtype) = (attn_bias.device, attn_bias.dtype)
482
+ attn_bias = attn_bias.add(
483
+ build_alibi_bias(
484
+ n_heads,
485
+ seq_len,
486
+ full=not causal,
487
+ alibi_bias_max=alibi_bias_max,
488
+ device=device,
489
+ dtype=dtype,
490
+ )
491
+ )
492
+ return attn_bias
493
+ else:
494
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
495
+
496
+
497
+ def gen_slopes(n_heads, alibi_bias_max=8, device=None):
498
+ _n_heads = 2 ** math.ceil(math.log2(n_heads))
499
+ m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
500
+ m = m.mul(alibi_bias_max / _n_heads)
501
+ slopes = 1.0 / torch.pow(2, m)
502
+ if _n_heads != n_heads:
503
+ slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
504
+ return slopes.view(1, n_heads, 1, 1)
505
+
506
+
507
+ def build_alibi_bias(
508
+ n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None
509
+ ):
510
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
511
+ 1, 1, 1, seq_len
512
+ )
513
+ if full:
514
+ alibi_bias = alibi_bias - torch.arange(
515
+ 1 - seq_len, 1, dtype=torch.int32, device=device
516
+ ).view(1, 1, seq_len, 1)
517
+ alibi_bias = alibi_bias.abs().mul(-1)
518
+ slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
519
+ alibi_bias = alibi_bias * slopes
520
+ return alibi_bias.to(dtype=dtype)
521
+
522
+
523
+ ATTN_CLASS_REGISTRY = {
524
+ "multihead_attention": MultiheadAttention,
525
+ "multiquery_attention": MultiQueryAttention,
526
+ }
lisa_on_cuda/llava/model/language_model/mpt/blocks.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GPT Blocks used for the GPT Model."""
2
+ from typing import Dict, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .attention import ATTN_CLASS_REGISTRY
8
+ from .norm import NORM_CLASS_REGISTRY
9
+
10
+
11
+ class MPTMLP(nn.Module):
12
+ def __init__(
13
+ self, d_model: int, expansion_ratio: int, device: Optional[str] = None
14
+ ):
15
+ super().__init__()
16
+ self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
17
+ self.act = nn.GELU(approximate="none")
18
+ self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
19
+ self.down_proj._is_residual = True
20
+
21
+ def forward(self, x):
22
+ return self.down_proj(self.act(self.up_proj(x)))
23
+
24
+
25
+ class MPTBlock(nn.Module):
26
+ def __init__(
27
+ self,
28
+ d_model: int,
29
+ n_heads: int,
30
+ expansion_ratio: int,
31
+ attn_config: Dict = {
32
+ "attn_type": "multihead_attention",
33
+ "attn_pdrop": 0.0,
34
+ "attn_impl": "triton",
35
+ "qk_ln": False,
36
+ "clip_qkv": None,
37
+ "softmax_scale": None,
38
+ "prefix_lm": False,
39
+ "attn_uses_sequence_id": False,
40
+ "alibi": False,
41
+ "alibi_bias_max": 8,
42
+ },
43
+ resid_pdrop: float = 0.0,
44
+ norm_type: str = "low_precision_layernorm",
45
+ verbose: int = 0,
46
+ device: Optional[str] = None,
47
+ **kwargs
48
+ ):
49
+ del kwargs
50
+ super().__init__()
51
+ norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
52
+ attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
53
+ self.norm_1 = norm_class(d_model, device=device)
54
+ self.attn = attn_class(
55
+ attn_impl=attn_config["attn_impl"],
56
+ clip_qkv=attn_config["clip_qkv"],
57
+ qk_ln=attn_config["qk_ln"],
58
+ softmax_scale=attn_config["softmax_scale"],
59
+ attn_pdrop=attn_config["attn_pdrop"],
60
+ d_model=d_model,
61
+ n_heads=n_heads,
62
+ verbose=verbose,
63
+ device=device,
64
+ )
65
+ self.norm_2 = norm_class(d_model, device=device)
66
+ self.ffn = MPTMLP(
67
+ d_model=d_model, expansion_ratio=expansion_ratio, device=device
68
+ )
69
+ self.resid_attn_dropout = nn.Dropout(resid_pdrop)
70
+ self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
71
+
72
+ def forward(
73
+ self,
74
+ x: torch.Tensor,
75
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
76
+ attn_bias: Optional[torch.Tensor] = None,
77
+ attention_mask: Optional[torch.ByteTensor] = None,
78
+ is_causal: bool = True,
79
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
80
+ a = self.norm_1(x)
81
+ (b, attn_weights, past_key_value) = self.attn(
82
+ a,
83
+ past_key_value=past_key_value,
84
+ attn_bias=attn_bias,
85
+ attention_mask=attention_mask,
86
+ is_causal=is_causal,
87
+ )
88
+ x = x + self.resid_attn_dropout(b)
89
+ m = self.norm_2(x)
90
+ n = self.ffn(m)
91
+ x = x + self.resid_ffn_dropout(n)
92
+ return (x, attn_weights, past_key_value)
lisa_on_cuda/llava/model/language_model/mpt/configuration_mpt.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A HuggingFace-style model configuration."""
2
+ from typing import Dict, Optional, Union
3
+
4
+ from transformers import PretrainedConfig
5
+
6
+ attn_config_defaults: Dict = {
7
+ "attn_type": "multihead_attention",
8
+ "attn_pdrop": 0.0,
9
+ "attn_impl": "triton",
10
+ "qk_ln": False,
11
+ "clip_qkv": None,
12
+ "softmax_scale": None,
13
+ "prefix_lm": False,
14
+ "attn_uses_sequence_id": False,
15
+ "alibi": False,
16
+ "alibi_bias_max": 8,
17
+ }
18
+ init_config_defaults: Dict = {
19
+ "name": "kaiming_normal_",
20
+ "fan_mode": "fan_in",
21
+ "init_nonlinearity": "relu",
22
+ "init_div_is_residual": True,
23
+ "emb_init_std": None,
24
+ "emb_init_uniform_lim": None,
25
+ "init_std": None,
26
+ "init_gain": 0.0,
27
+ }
28
+
29
+
30
+ class MPTConfig(PretrainedConfig):
31
+ model_type = "mpt"
32
+
33
+ def __init__(
34
+ self,
35
+ d_model: int = 2048,
36
+ n_heads: int = 16,
37
+ n_layers: int = 24,
38
+ expansion_ratio: int = 4,
39
+ max_seq_len: int = 2048,
40
+ vocab_size: int = 50368,
41
+ resid_pdrop: float = 0.0,
42
+ emb_pdrop: float = 0.0,
43
+ learned_pos_emb: bool = True,
44
+ attn_config: Dict = attn_config_defaults,
45
+ init_device: str = "cpu",
46
+ logit_scale: Optional[Union[float, str]] = None,
47
+ no_bias: bool = False,
48
+ verbose: int = 0,
49
+ embedding_fraction: float = 1.0,
50
+ norm_type: str = "low_precision_layernorm",
51
+ use_cache: bool = False,
52
+ init_config: Dict = init_config_defaults,
53
+ **kwargs,
54
+ ):
55
+ """The MPT configuration class.
56
+
57
+ Args:
58
+ d_model (int): The size of the embedding dimension of the model.
59
+ n_heads (int): The number of attention heads.
60
+ n_layers (int): The number of layers in the model.
61
+ expansion_ratio (int): The ratio of the up/down scale in the MLP.
62
+ max_seq_len (int): The maximum sequence length of the model.
63
+ vocab_size (int): The size of the vocabulary.
64
+ resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
65
+ emb_pdrop (float): The dropout probability for the embedding layer.
66
+ learned_pos_emb (bool): Whether to use learned positional embeddings
67
+ attn_config (Dict): A dictionary used to configure the model's attention module:
68
+ attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
69
+ attn_pdrop (float): The dropout probability for the attention layers.
70
+ attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
71
+ qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
72
+ clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
73
+ this value.
74
+ softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
75
+ use the default scale of ``1/sqrt(d_keys)``.
76
+ prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
77
+ extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
78
+ can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
79
+ attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
80
+ When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
81
+ which sub-sequence each token belongs to.
82
+ Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
83
+ alibi (bool): Whether to use the alibi bias instead of position embeddings.
84
+ alibi_bias_max (int): The maximum value of the alibi bias.
85
+ init_device (str): The device to use for parameter initialization.
86
+ logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
87
+ no_bias (bool): Whether to use bias in all layers.
88
+ verbose (int): The verbosity level. 0 is silent.
89
+ embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
90
+ norm_type (str): choose type of norm to use
91
+ multiquery_attention (bool): Whether to use multiquery attention implementation.
92
+ use_cache (bool): Whether or not the model should return the last key/values attentions
93
+ init_config (Dict): A dictionary used to configure the model initialization:
94
+ init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
95
+ 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
96
+ 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
97
+ init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
98
+ emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
99
+ emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
100
+ used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
101
+ init_std (float): The standard deviation of the normal distribution used to initialize the model,
102
+ if using the baseline_ parameter initialization scheme.
103
+ init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
104
+ fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
105
+ init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
106
+ ---
107
+ See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
108
+ """
109
+ self.d_model = d_model
110
+ self.n_heads = n_heads
111
+ self.n_layers = n_layers
112
+ self.expansion_ratio = expansion_ratio
113
+ self.max_seq_len = max_seq_len
114
+ self.vocab_size = vocab_size
115
+ self.resid_pdrop = resid_pdrop
116
+ self.emb_pdrop = emb_pdrop
117
+ self.learned_pos_emb = learned_pos_emb
118
+ self.attn_config = attn_config
119
+ self.init_device = init_device
120
+ self.logit_scale = logit_scale
121
+ self.no_bias = no_bias
122
+ self.verbose = verbose
123
+ self.embedding_fraction = embedding_fraction
124
+ self.norm_type = norm_type
125
+ self.use_cache = use_cache
126
+ self.init_config = init_config
127
+ if "name" in kwargs:
128
+ del kwargs["name"]
129
+ if "loss_fn" in kwargs:
130
+ del kwargs["loss_fn"]
131
+ super().__init__(**kwargs)
132
+ self._validate_config()
133
+
134
+ def _set_config_defaults(self, config, config_defaults):
135
+ for k, v in config_defaults.items():
136
+ if k not in config:
137
+ config[k] = v
138
+ return config
139
+
140
+ def _validate_config(self):
141
+ self.attn_config = self._set_config_defaults(
142
+ self.attn_config, attn_config_defaults
143
+ )
144
+ self.init_config = self._set_config_defaults(
145
+ self.init_config, init_config_defaults
146
+ )
147
+ if self.d_model % self.n_heads != 0:
148
+ raise ValueError("d_model must be divisible by n_heads")
149
+ if any(
150
+ (
151
+ prob < 0 or prob > 1
152
+ for prob in [
153
+ self.attn_config["attn_pdrop"],
154
+ self.resid_pdrop,
155
+ self.emb_pdrop,
156
+ ]
157
+ )
158
+ ):
159
+ raise ValueError(
160
+ "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
161
+ )
162
+ if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
163
+ raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
164
+ if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [
165
+ "torch",
166
+ "triton",
167
+ ]:
168
+ raise NotImplementedError(
169
+ "prefix_lm only implemented with torch and triton attention."
170
+ )
171
+ if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [
172
+ "torch",
173
+ "triton",
174
+ ]:
175
+ raise NotImplementedError(
176
+ "alibi only implemented with torch and triton attention."
177
+ )
178
+ if self.attn_config["attn_uses_sequence_id"] and self.attn_config[
179
+ "attn_impl"
180
+ ] not in ["torch", "triton"]:
181
+ raise NotImplementedError(
182
+ "attn_uses_sequence_id only implemented with torch and triton attention."
183
+ )
184
+ if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
185
+ raise ValueError(
186
+ "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!"
187
+ )
188
+ if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
189
+ raise ValueError(
190
+ f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
191
+ )
192
+ if self.init_config.get("name", None) is None:
193
+ raise ValueError(
194
+ f"self.init_config={self.init_config!r} 'name' needs to be set."
195
+ )
196
+ if not self.learned_pos_emb and (not self.attn_config["alibi"]):
197
+ raise ValueError(
198
+ f"Positional information must be provided to the model using either learned_pos_emb or alibi."
199
+ )
lisa_on_cuda/llava/model/language_model/mpt/custom_embedding.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+
6
+
7
+ class SharedEmbedding(nn.Embedding):
8
+ def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
9
+ if unembed:
10
+ return F.linear(input, self.weight)
11
+ return super().forward(input)
lisa_on_cuda/llava/model/language_model/mpt/flash_attn_triton.py ADDED
@@ -0,0 +1,1087 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
3
+ update imports to use 'triton_pre_mlir'
4
+
5
+ *Experimental* implementation of FlashAttention in Triton.
6
+ Tested with triton==2.0.0.dev20221202.
7
+ Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
8
+ other than 64:
9
+ https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
10
+ We'll update this implementation with the new Triton backend once this is fixed.
11
+
12
+ We use the FlashAttention implementation from Phil Tillet a starting point.
13
+ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
14
+
15
+ Changes:
16
+ - Implement both causal and non-causal attention.
17
+ - Implement both self-attention and cross-attention.
18
+ - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
19
+ - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
20
+ - Support attention bias.
21
+ - Speed up the forward pass a bit, and only store the LSE instead of m and l.
22
+ - Make the backward for d=128 much faster by reducing register spilling.
23
+ - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
24
+ small batch size * nheads.
25
+
26
+ Caution:
27
+ - This is an *experimental* implementation. The forward pass should be quite robust but
28
+ I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
29
+ - This implementation has only been tested on A100.
30
+ - If you plan to use headdim other than 64 and 128, you should test for race conditions
31
+ (due to the Triton compiler), as done in tests/test_flash_attn.py
32
+ "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
33
+ for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
34
+ that there are none left for other head dimensions.
35
+
36
+ Differences between this Triton version and the CUDA version:
37
+ - Triton version doesn't support dropout.
38
+ - Triton forward is generally faster than CUDA forward, while Triton backward is
39
+ generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
40
+ than CUDA forward + backward.
41
+ - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
42
+ - Triton version supports attention bias, while CUDA version doesn't.
43
+ """
44
+ import math
45
+
46
+ import torch
47
+ import triton_pre_mlir as triton
48
+ import triton_pre_mlir.language as tl
49
+
50
+
51
+ @triton.heuristics(
52
+ {
53
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
54
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
55
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
56
+ }
57
+ )
58
+ @triton.jit
59
+ def _fwd_kernel(
60
+ Q,
61
+ K,
62
+ V,
63
+ Bias,
64
+ Out,
65
+ Lse,
66
+ TMP,
67
+ softmax_scale,
68
+ stride_qb,
69
+ stride_qh,
70
+ stride_qm,
71
+ stride_kb,
72
+ stride_kh,
73
+ stride_kn,
74
+ stride_vb,
75
+ stride_vh,
76
+ stride_vn,
77
+ stride_bb,
78
+ stride_bh,
79
+ stride_bm,
80
+ stride_ob,
81
+ stride_oh,
82
+ stride_om,
83
+ nheads,
84
+ seqlen_q,
85
+ seqlen_k,
86
+ seqlen_q_rounded,
87
+ headdim,
88
+ CACHE_KEY_SEQLEN_Q,
89
+ CACHE_KEY_SEQLEN_K,
90
+ BIAS_TYPE: tl.constexpr,
91
+ IS_CAUSAL: tl.constexpr,
92
+ BLOCK_HEADDIM: tl.constexpr,
93
+ EVEN_M: tl.constexpr,
94
+ EVEN_N: tl.constexpr,
95
+ EVEN_HEADDIM: tl.constexpr,
96
+ BLOCK_M: tl.constexpr,
97
+ BLOCK_N: tl.constexpr,
98
+ ):
99
+ start_m = tl.program_id(0)
100
+ off_hb = tl.program_id(1)
101
+ off_b = off_hb // nheads
102
+ off_h = off_hb % nheads
103
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
104
+ offs_n = tl.arange(0, BLOCK_N)
105
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
106
+ q_ptrs = (
107
+ Q
108
+ + off_b * stride_qb
109
+ + off_h * stride_qh
110
+ + (offs_m[:, None] * stride_qm + offs_d[None, :])
111
+ )
112
+ k_ptrs = (
113
+ K
114
+ + off_b * stride_kb
115
+ + off_h * stride_kh
116
+ + (offs_n[:, None] * stride_kn + offs_d[None, :])
117
+ )
118
+ v_ptrs = (
119
+ V
120
+ + off_b * stride_vb
121
+ + off_h * stride_vh
122
+ + (offs_n[:, None] * stride_vn + offs_d[None, :])
123
+ )
124
+ if BIAS_TYPE == "vector":
125
+ b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
126
+ elif BIAS_TYPE == "matrix":
127
+ b_ptrs = (
128
+ Bias
129
+ + off_b * stride_bb
130
+ + off_h * stride_bh
131
+ + (offs_m[:, None] * stride_bm + offs_n[None, :])
132
+ )
133
+ t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
134
+ lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
135
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
136
+ acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
137
+ if EVEN_M & EVEN_N:
138
+ if EVEN_HEADDIM:
139
+ q = tl.load(q_ptrs)
140
+ else:
141
+ q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
142
+ elif EVEN_HEADDIM:
143
+ q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
144
+ else:
145
+ q = tl.load(
146
+ q_ptrs,
147
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
148
+ other=0.0,
149
+ )
150
+ end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
151
+ for start_n in range(0, end_n, BLOCK_N):
152
+ start_n = tl.multiple_of(start_n, BLOCK_N)
153
+ if EVEN_N & EVEN_M:
154
+ if EVEN_HEADDIM:
155
+ k = tl.load(k_ptrs + start_n * stride_kn)
156
+ else:
157
+ k = tl.load(
158
+ k_ptrs + start_n * stride_kn,
159
+ mask=offs_d[None, :] < headdim,
160
+ other=0.0,
161
+ )
162
+ elif EVEN_HEADDIM:
163
+ k = tl.load(
164
+ k_ptrs + start_n * stride_kn,
165
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
166
+ other=0.0,
167
+ )
168
+ else:
169
+ k = tl.load(
170
+ k_ptrs + start_n * stride_kn,
171
+ mask=((start_n + offs_n)[:, None] < seqlen_k)
172
+ & (offs_d[None, :] < headdim),
173
+ other=0.0,
174
+ )
175
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
176
+ qk += tl.dot(q, k, trans_b=True)
177
+ if not EVEN_N:
178
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
179
+ if IS_CAUSAL:
180
+ qk += tl.where(
181
+ offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")
182
+ )
183
+ if BIAS_TYPE != "none":
184
+ if BIAS_TYPE == "vector":
185
+ if EVEN_N:
186
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
187
+ else:
188
+ bias = tl.load(
189
+ b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0
190
+ ).to(tl.float32)
191
+ bias = bias[None, :]
192
+ elif BIAS_TYPE == "matrix":
193
+ if EVEN_M & EVEN_N:
194
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
195
+ else:
196
+ bias = tl.load(
197
+ b_ptrs + start_n,
198
+ mask=(offs_m[:, None] < seqlen_q)
199
+ & ((start_n + offs_n)[None, :] < seqlen_k),
200
+ other=0.0,
201
+ ).to(tl.float32)
202
+ qk = qk * softmax_scale + bias
203
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
204
+ p = tl.exp(qk - m_ij[:, None])
205
+ else:
206
+ m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
207
+ p = tl.exp(qk * softmax_scale - m_ij[:, None])
208
+ l_ij = tl.sum(p, 1)
209
+ acc_o_scale = tl.exp(m_i - m_ij)
210
+ tl.store(t_ptrs, acc_o_scale)
211
+ acc_o_scale = tl.load(t_ptrs)
212
+ acc_o = acc_o * acc_o_scale[:, None]
213
+ if EVEN_N & EVEN_M:
214
+ if EVEN_HEADDIM:
215
+ v = tl.load(v_ptrs + start_n * stride_vn)
216
+ else:
217
+ v = tl.load(
218
+ v_ptrs + start_n * stride_vn,
219
+ mask=offs_d[None, :] < headdim,
220
+ other=0.0,
221
+ )
222
+ elif EVEN_HEADDIM:
223
+ v = tl.load(
224
+ v_ptrs + start_n * stride_vn,
225
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
226
+ other=0.0,
227
+ )
228
+ else:
229
+ v = tl.load(
230
+ v_ptrs + start_n * stride_vn,
231
+ mask=((start_n + offs_n)[:, None] < seqlen_k)
232
+ & (offs_d[None, :] < headdim),
233
+ other=0.0,
234
+ )
235
+ p = p.to(v.dtype)
236
+ acc_o += tl.dot(p, v)
237
+ m_i = m_ij
238
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
239
+ lse_i = m_ij + tl.log(l_i_new)
240
+ o_scale = tl.exp(m_i - lse_i)
241
+ tl.store(t_ptrs, o_scale)
242
+ o_scale = tl.load(t_ptrs)
243
+ acc_o = acc_o * o_scale[:, None]
244
+ start_m = tl.program_id(0)
245
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
246
+ lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
247
+ tl.store(lse_ptrs, lse_i)
248
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
249
+ out_ptrs = (
250
+ Out
251
+ + off_b * stride_ob
252
+ + off_h * stride_oh
253
+ + (offs_m[:, None] * stride_om + offs_d[None, :])
254
+ )
255
+ if EVEN_M:
256
+ if EVEN_HEADDIM:
257
+ tl.store(out_ptrs, acc_o)
258
+ else:
259
+ tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
260
+ elif EVEN_HEADDIM:
261
+ tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
262
+ else:
263
+ tl.store(
264
+ out_ptrs,
265
+ acc_o,
266
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
267
+ )
268
+
269
+
270
+ @triton.jit
271
+ def _bwd_preprocess_do_o_dot(
272
+ Out,
273
+ DO,
274
+ Delta,
275
+ stride_ob,
276
+ stride_oh,
277
+ stride_om,
278
+ stride_dob,
279
+ stride_doh,
280
+ stride_dom,
281
+ nheads,
282
+ seqlen_q,
283
+ seqlen_q_rounded,
284
+ headdim,
285
+ BLOCK_M: tl.constexpr,
286
+ BLOCK_HEADDIM: tl.constexpr,
287
+ ):
288
+ start_m = tl.program_id(0)
289
+ off_hb = tl.program_id(1)
290
+ off_b = off_hb // nheads
291
+ off_h = off_hb % nheads
292
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
293
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
294
+ o = tl.load(
295
+ Out
296
+ + off_b * stride_ob
297
+ + off_h * stride_oh
298
+ + offs_m[:, None] * stride_om
299
+ + offs_d[None, :],
300
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
301
+ other=0.0,
302
+ ).to(tl.float32)
303
+ do = tl.load(
304
+ DO
305
+ + off_b * stride_dob
306
+ + off_h * stride_doh
307
+ + offs_m[:, None] * stride_dom
308
+ + offs_d[None, :],
309
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
310
+ other=0.0,
311
+ ).to(tl.float32)
312
+ delta = tl.sum(o * do, axis=1)
313
+ tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
314
+
315
+
316
+ @triton.jit
317
+ def _bwd_store_dk_dv(
318
+ dk_ptrs,
319
+ dv_ptrs,
320
+ dk,
321
+ dv,
322
+ offs_n,
323
+ offs_d,
324
+ seqlen_k,
325
+ headdim,
326
+ EVEN_M: tl.constexpr,
327
+ EVEN_N: tl.constexpr,
328
+ EVEN_HEADDIM: tl.constexpr,
329
+ ):
330
+ if EVEN_N & EVEN_M:
331
+ if EVEN_HEADDIM:
332
+ tl.store(dv_ptrs, dv)
333
+ tl.store(dk_ptrs, dk)
334
+ else:
335
+ tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
336
+ tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
337
+ elif EVEN_HEADDIM:
338
+ tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
339
+ tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
340
+ else:
341
+ tl.store(
342
+ dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)
343
+ )
344
+ tl.store(
345
+ dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)
346
+ )
347
+
348
+
349
+ @triton.jit
350
+ def _bwd_kernel_one_col_block(
351
+ start_n,
352
+ Q,
353
+ K,
354
+ V,
355
+ Bias,
356
+ DO,
357
+ DQ,
358
+ DK,
359
+ DV,
360
+ LSE,
361
+ D,
362
+ softmax_scale,
363
+ stride_qm,
364
+ stride_kn,
365
+ stride_vn,
366
+ stride_bm,
367
+ stride_dom,
368
+ stride_dqm,
369
+ stride_dkn,
370
+ stride_dvn,
371
+ seqlen_q,
372
+ seqlen_k,
373
+ headdim,
374
+ ATOMIC_ADD: tl.constexpr,
375
+ BIAS_TYPE: tl.constexpr,
376
+ IS_CAUSAL: tl.constexpr,
377
+ BLOCK_HEADDIM: tl.constexpr,
378
+ EVEN_M: tl.constexpr,
379
+ EVEN_N: tl.constexpr,
380
+ EVEN_HEADDIM: tl.constexpr,
381
+ BLOCK_M: tl.constexpr,
382
+ BLOCK_N: tl.constexpr,
383
+ ):
384
+ begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
385
+ offs_qm = begin_m + tl.arange(0, BLOCK_M)
386
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
387
+ offs_m = tl.arange(0, BLOCK_M)
388
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
389
+ q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
390
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
391
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
392
+ do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
393
+ dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
394
+ if BIAS_TYPE == "vector":
395
+ b_ptrs = Bias + offs_n
396
+ elif BIAS_TYPE == "matrix":
397
+ b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
398
+ dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
399
+ dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
400
+ if begin_m >= seqlen_q:
401
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
402
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
403
+ _bwd_store_dk_dv(
404
+ dk_ptrs,
405
+ dv_ptrs,
406
+ dk,
407
+ dv,
408
+ offs_n,
409
+ offs_d,
410
+ seqlen_k,
411
+ headdim,
412
+ EVEN_M=EVEN_M,
413
+ EVEN_N=EVEN_N,
414
+ EVEN_HEADDIM=EVEN_HEADDIM,
415
+ )
416
+ return
417
+ if EVEN_N & EVEN_M:
418
+ if EVEN_HEADDIM:
419
+ k = tl.load(k_ptrs)
420
+ v = tl.load(v_ptrs)
421
+ else:
422
+ k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
423
+ v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
424
+ elif EVEN_HEADDIM:
425
+ k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
426
+ v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
427
+ else:
428
+ k = tl.load(
429
+ k_ptrs,
430
+ mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
431
+ other=0.0,
432
+ )
433
+ v = tl.load(
434
+ v_ptrs,
435
+ mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
436
+ other=0.0,
437
+ )
438
+ num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
439
+ for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
440
+ start_m = tl.multiple_of(start_m, BLOCK_M)
441
+ offs_m_curr = start_m + offs_m
442
+ if EVEN_M & EVEN_HEADDIM:
443
+ q = tl.load(q_ptrs)
444
+ elif EVEN_HEADDIM:
445
+ q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
446
+ else:
447
+ q = tl.load(
448
+ q_ptrs,
449
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
450
+ other=0.0,
451
+ )
452
+ qk = tl.dot(q, k, trans_b=True)
453
+ if not EVEN_N:
454
+ qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
455
+ if IS_CAUSAL:
456
+ qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float("-inf"))
457
+ if BIAS_TYPE != "none":
458
+ tl.debug_barrier()
459
+ if BIAS_TYPE == "vector":
460
+ if EVEN_N:
461
+ bias = tl.load(b_ptrs).to(tl.float32)
462
+ else:
463
+ bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(
464
+ tl.float32
465
+ )
466
+ bias = bias[None, :]
467
+ elif BIAS_TYPE == "matrix":
468
+ if EVEN_M & EVEN_N:
469
+ bias = tl.load(b_ptrs).to(tl.float32)
470
+ else:
471
+ bias = tl.load(
472
+ b_ptrs,
473
+ mask=(offs_m_curr[:, None] < seqlen_q)
474
+ & (offs_n[None, :] < seqlen_k),
475
+ other=0.0,
476
+ ).to(tl.float32)
477
+ qk = qk * softmax_scale + bias
478
+ if not EVEN_M & EVEN_HEADDIM:
479
+ tl.debug_barrier()
480
+ lse_i = tl.load(LSE + offs_m_curr)
481
+ if BIAS_TYPE == "none":
482
+ p = tl.exp(qk * softmax_scale - lse_i[:, None])
483
+ else:
484
+ p = tl.exp(qk - lse_i[:, None])
485
+ if EVEN_M & EVEN_HEADDIM:
486
+ do = tl.load(do_ptrs)
487
+ else:
488
+ do = tl.load(
489
+ do_ptrs,
490
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
491
+ other=0.0,
492
+ )
493
+ dv += tl.dot(p.to(do.dtype), do, trans_a=True)
494
+ if not EVEN_M & EVEN_HEADDIM:
495
+ tl.debug_barrier()
496
+ dp = tl.dot(do, v, trans_b=True)
497
+ if not EVEN_HEADDIM:
498
+ tl.debug_barrier()
499
+ Di = tl.load(D + offs_m_curr)
500
+ ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
501
+ dk += tl.dot(ds, q, trans_a=True)
502
+ if not EVEN_M & EVEN_HEADDIM:
503
+ tl.debug_barrier()
504
+ if not ATOMIC_ADD:
505
+ if EVEN_M & EVEN_HEADDIM:
506
+ dq = tl.load(dq_ptrs, eviction_policy="evict_last")
507
+ dq += tl.dot(ds, k)
508
+ tl.store(dq_ptrs, dq, eviction_policy="evict_last")
509
+ elif EVEN_HEADDIM:
510
+ dq = tl.load(
511
+ dq_ptrs,
512
+ mask=offs_m_curr[:, None] < seqlen_q,
513
+ other=0.0,
514
+ eviction_policy="evict_last",
515
+ )
516
+ dq += tl.dot(ds, k)
517
+ tl.store(
518
+ dq_ptrs,
519
+ dq,
520
+ mask=offs_m_curr[:, None] < seqlen_q,
521
+ eviction_policy="evict_last",
522
+ )
523
+ else:
524
+ dq = tl.load(
525
+ dq_ptrs,
526
+ mask=(offs_m_curr[:, None] < seqlen_q)
527
+ & (offs_d[None, :] < headdim),
528
+ other=0.0,
529
+ eviction_policy="evict_last",
530
+ )
531
+ dq += tl.dot(ds, k)
532
+ tl.store(
533
+ dq_ptrs,
534
+ dq,
535
+ mask=(offs_m_curr[:, None] < seqlen_q)
536
+ & (offs_d[None, :] < headdim),
537
+ eviction_policy="evict_last",
538
+ )
539
+ else:
540
+ dq = tl.dot(ds, k)
541
+ if EVEN_M & EVEN_HEADDIM:
542
+ tl.atomic_add(dq_ptrs, dq)
543
+ elif EVEN_HEADDIM:
544
+ tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
545
+ else:
546
+ tl.atomic_add(
547
+ dq_ptrs,
548
+ dq,
549
+ mask=(offs_m_curr[:, None] < seqlen_q)
550
+ & (offs_d[None, :] < headdim),
551
+ )
552
+ dq_ptrs += BLOCK_M * stride_dqm
553
+ q_ptrs += BLOCK_M * stride_qm
554
+ do_ptrs += BLOCK_M * stride_dom
555
+ if BIAS_TYPE == "matrix":
556
+ b_ptrs += BLOCK_M * stride_bm
557
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
558
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
559
+ _bwd_store_dk_dv(
560
+ dk_ptrs,
561
+ dv_ptrs,
562
+ dk,
563
+ dv,
564
+ offs_n,
565
+ offs_d,
566
+ seqlen_k,
567
+ headdim,
568
+ EVEN_M=EVEN_M,
569
+ EVEN_N=EVEN_N,
570
+ EVEN_HEADDIM=EVEN_HEADDIM,
571
+ )
572
+
573
+
574
+ def init_to_zero(name):
575
+ return lambda nargs: nargs[name].zero_()
576
+
577
+
578
+ @triton.autotune(
579
+ configs=[
580
+ triton.Config(
581
+ {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
582
+ num_warps=8,
583
+ num_stages=1,
584
+ pre_hook=init_to_zero("DQ"),
585
+ ),
586
+ triton.Config(
587
+ {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
588
+ num_warps=8,
589
+ num_stages=1,
590
+ pre_hook=init_to_zero("DQ"),
591
+ ),
592
+ ],
593
+ key=[
594
+ "CACHE_KEY_SEQLEN_Q",
595
+ "CACHE_KEY_SEQLEN_K",
596
+ "BIAS_TYPE",
597
+ "IS_CAUSAL",
598
+ "BLOCK_HEADDIM",
599
+ ],
600
+ )
601
+ @triton.heuristics(
602
+ {
603
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
604
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
605
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
606
+ }
607
+ )
608
+ @triton.jit
609
+ def _bwd_kernel(
610
+ Q,
611
+ K,
612
+ V,
613
+ Bias,
614
+ DO,
615
+ DQ,
616
+ DK,
617
+ DV,
618
+ LSE,
619
+ D,
620
+ softmax_scale,
621
+ stride_qb,
622
+ stride_qh,
623
+ stride_qm,
624
+ stride_kb,
625
+ stride_kh,
626
+ stride_kn,
627
+ stride_vb,
628
+ stride_vh,
629
+ stride_vn,
630
+ stride_bb,
631
+ stride_bh,
632
+ stride_bm,
633
+ stride_dob,
634
+ stride_doh,
635
+ stride_dom,
636
+ stride_dqb,
637
+ stride_dqh,
638
+ stride_dqm,
639
+ stride_dkb,
640
+ stride_dkh,
641
+ stride_dkn,
642
+ stride_dvb,
643
+ stride_dvh,
644
+ stride_dvn,
645
+ nheads,
646
+ seqlen_q,
647
+ seqlen_k,
648
+ seqlen_q_rounded,
649
+ headdim,
650
+ CACHE_KEY_SEQLEN_Q,
651
+ CACHE_KEY_SEQLEN_K,
652
+ BIAS_TYPE: tl.constexpr,
653
+ IS_CAUSAL: tl.constexpr,
654
+ BLOCK_HEADDIM: tl.constexpr,
655
+ SEQUENCE_PARALLEL: tl.constexpr,
656
+ EVEN_M: tl.constexpr,
657
+ EVEN_N: tl.constexpr,
658
+ EVEN_HEADDIM: tl.constexpr,
659
+ BLOCK_M: tl.constexpr,
660
+ BLOCK_N: tl.constexpr,
661
+ ):
662
+ off_hb = tl.program_id(1)
663
+ off_b = off_hb // nheads
664
+ off_h = off_hb % nheads
665
+ Q += off_b * stride_qb + off_h * stride_qh
666
+ K += off_b * stride_kb + off_h * stride_kh
667
+ V += off_b * stride_vb + off_h * stride_vh
668
+ DO += off_b * stride_dob + off_h * stride_doh
669
+ DQ += off_b * stride_dqb + off_h * stride_dqh
670
+ DK += off_b * stride_dkb + off_h * stride_dkh
671
+ DV += off_b * stride_dvb + off_h * stride_dvh
672
+ if BIAS_TYPE != "none":
673
+ Bias += off_b * stride_bb + off_h * stride_bh
674
+ D += off_hb * seqlen_q_rounded
675
+ LSE += off_hb * seqlen_q_rounded
676
+ if not SEQUENCE_PARALLEL:
677
+ num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
678
+ for start_n in range(0, num_block_n):
679
+ _bwd_kernel_one_col_block(
680
+ start_n,
681
+ Q,
682
+ K,
683
+ V,
684
+ Bias,
685
+ DO,
686
+ DQ,
687
+ DK,
688
+ DV,
689
+ LSE,
690
+ D,
691
+ softmax_scale,
692
+ stride_qm,
693
+ stride_kn,
694
+ stride_vn,
695
+ stride_bm,
696
+ stride_dom,
697
+ stride_dqm,
698
+ stride_dkn,
699
+ stride_dvn,
700
+ seqlen_q,
701
+ seqlen_k,
702
+ headdim,
703
+ ATOMIC_ADD=False,
704
+ BIAS_TYPE=BIAS_TYPE,
705
+ IS_CAUSAL=IS_CAUSAL,
706
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
707
+ EVEN_M=EVEN_M,
708
+ EVEN_N=EVEN_N,
709
+ EVEN_HEADDIM=EVEN_HEADDIM,
710
+ BLOCK_M=BLOCK_M,
711
+ BLOCK_N=BLOCK_N,
712
+ )
713
+ else:
714
+ start_n = tl.program_id(0)
715
+ _bwd_kernel_one_col_block(
716
+ start_n,
717
+ Q,
718
+ K,
719
+ V,
720
+ Bias,
721
+ DO,
722
+ DQ,
723
+ DK,
724
+ DV,
725
+ LSE,
726
+ D,
727
+ softmax_scale,
728
+ stride_qm,
729
+ stride_kn,
730
+ stride_vn,
731
+ stride_bm,
732
+ stride_dom,
733
+ stride_dqm,
734
+ stride_dkn,
735
+ stride_dvn,
736
+ seqlen_q,
737
+ seqlen_k,
738
+ headdim,
739
+ ATOMIC_ADD=True,
740
+ BIAS_TYPE=BIAS_TYPE,
741
+ IS_CAUSAL=IS_CAUSAL,
742
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
743
+ EVEN_M=EVEN_M,
744
+ EVEN_N=EVEN_N,
745
+ EVEN_HEADDIM=EVEN_HEADDIM,
746
+ BLOCK_M=BLOCK_M,
747
+ BLOCK_N=BLOCK_N,
748
+ )
749
+
750
+
751
+ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
752
+ (batch, seqlen_q, nheads, d) = q.shape
753
+ (_, seqlen_k, _, _) = k.shape
754
+ assert k.shape == (batch, seqlen_k, nheads, d)
755
+ assert v.shape == (batch, seqlen_k, nheads, d)
756
+ assert d <= 128, "FlashAttention only support head dimensions up to 128"
757
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
758
+ assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
759
+ assert q.is_cuda and k.is_cuda and v.is_cuda
760
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
761
+ has_bias = bias is not None
762
+ bias_type = "none"
763
+ if has_bias:
764
+ assert bias.dtype in [q.dtype, torch.float]
765
+ assert bias.is_cuda
766
+ assert bias.dim() == 4
767
+ if bias.stride(-1) != 1:
768
+ bias = bias.contiguous()
769
+ if bias.shape[2:] == (1, seqlen_k):
770
+ bias_type = "vector"
771
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
772
+ bias_type = "matrix"
773
+ else:
774
+ raise RuntimeError(
775
+ "Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)"
776
+ )
777
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
778
+ bias_strides = (
779
+ (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
780
+ )
781
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
782
+ lse = torch.empty(
783
+ (batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32
784
+ )
785
+ tmp = torch.empty(
786
+ (batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32
787
+ )
788
+ o = torch.empty_like(q)
789
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
790
+ BLOCK = 128
791
+ num_warps = 4 if d <= 64 else 8
792
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
793
+ _fwd_kernel[grid](
794
+ q,
795
+ k,
796
+ v,
797
+ bias,
798
+ o,
799
+ lse,
800
+ tmp,
801
+ softmax_scale,
802
+ q.stride(0),
803
+ q.stride(2),
804
+ q.stride(1),
805
+ k.stride(0),
806
+ k.stride(2),
807
+ k.stride(1),
808
+ v.stride(0),
809
+ v.stride(2),
810
+ v.stride(1),
811
+ *bias_strides,
812
+ o.stride(0),
813
+ o.stride(2),
814
+ o.stride(1),
815
+ nheads,
816
+ seqlen_q,
817
+ seqlen_k,
818
+ seqlen_q_rounded,
819
+ d,
820
+ seqlen_q // 32,
821
+ seqlen_k // 32,
822
+ bias_type,
823
+ causal,
824
+ BLOCK_HEADDIM,
825
+ BLOCK_M=BLOCK,
826
+ BLOCK_N=BLOCK,
827
+ num_warps=num_warps,
828
+ num_stages=1
829
+ )
830
+ return (o, lse, softmax_scale)
831
+
832
+
833
+ def _flash_attn_backward(
834
+ do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None
835
+ ):
836
+ if do.stride(-1) != 1:
837
+ do = do.contiguous()
838
+ (batch, seqlen_q, nheads, d) = q.shape
839
+ (_, seqlen_k, _, _) = k.shape
840
+ assert d <= 128
841
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
842
+ assert lse.shape == (batch, nheads, seqlen_q_rounded)
843
+ assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
844
+ assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
845
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
846
+ dq_accum = torch.empty_like(q, dtype=torch.float32)
847
+ delta = torch.empty_like(lse)
848
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
849
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
850
+ _bwd_preprocess_do_o_dot[grid](
851
+ o,
852
+ do,
853
+ delta,
854
+ o.stride(0),
855
+ o.stride(2),
856
+ o.stride(1),
857
+ do.stride(0),
858
+ do.stride(2),
859
+ do.stride(1),
860
+ nheads,
861
+ seqlen_q,
862
+ seqlen_q_rounded,
863
+ d,
864
+ BLOCK_M=128,
865
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
866
+ )
867
+ has_bias = bias is not None
868
+ bias_type = "none"
869
+ if has_bias:
870
+ assert bias.dtype in [q.dtype, torch.float]
871
+ assert bias.is_cuda
872
+ assert bias.dim() == 4
873
+ assert bias.stride(-1) == 1
874
+ if bias.shape[2:] == (1, seqlen_k):
875
+ bias_type = "vector"
876
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
877
+ bias_type = "matrix"
878
+ else:
879
+ raise RuntimeError(
880
+ "Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)"
881
+ )
882
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
883
+ bias_strides = (
884
+ (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
885
+ )
886
+ grid = lambda META: (
887
+ triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
888
+ batch * nheads,
889
+ )
890
+ _bwd_kernel[grid](
891
+ q,
892
+ k,
893
+ v,
894
+ bias,
895
+ do,
896
+ dq_accum,
897
+ dk,
898
+ dv,
899
+ lse,
900
+ delta,
901
+ softmax_scale,
902
+ q.stride(0),
903
+ q.stride(2),
904
+ q.stride(1),
905
+ k.stride(0),
906
+ k.stride(2),
907
+ k.stride(1),
908
+ v.stride(0),
909
+ v.stride(2),
910
+ v.stride(1),
911
+ *bias_strides,
912
+ do.stride(0),
913
+ do.stride(2),
914
+ do.stride(1),
915
+ dq_accum.stride(0),
916
+ dq_accum.stride(2),
917
+ dq_accum.stride(1),
918
+ dk.stride(0),
919
+ dk.stride(2),
920
+ dk.stride(1),
921
+ dv.stride(0),
922
+ dv.stride(2),
923
+ dv.stride(1),
924
+ nheads,
925
+ seqlen_q,
926
+ seqlen_k,
927
+ seqlen_q_rounded,
928
+ d,
929
+ seqlen_q // 32,
930
+ seqlen_k // 32,
931
+ bias_type,
932
+ causal,
933
+ BLOCK_HEADDIM
934
+ )
935
+ dq.copy_(dq_accum)
936
+
937
+
938
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
939
+ @staticmethod
940
+ def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
941
+ """
942
+ qkv: (batch, seqlen, 3, nheads, headdim)
943
+ bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
944
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
945
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
946
+ """
947
+ if qkv.stride(-1) != 1:
948
+ qkv = qkv.contiguous()
949
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(
950
+ qkv[:, :, 0],
951
+ qkv[:, :, 1],
952
+ qkv[:, :, 2],
953
+ bias=bias,
954
+ causal=causal,
955
+ softmax_scale=softmax_scale,
956
+ )
957
+ ctx.save_for_backward(qkv, o, lse, bias)
958
+ ctx.causal = causal
959
+ return o
960
+
961
+ @staticmethod
962
+ def backward(ctx, do):
963
+ (qkv, o, lse, bias) = ctx.saved_tensors
964
+ assert not ctx.needs_input_grad[
965
+ 1
966
+ ], "FlashAttention does not support bias gradient yet"
967
+ with torch.inference_mode():
968
+ dqkv = torch.empty_like(qkv)
969
+ _flash_attn_backward(
970
+ do,
971
+ qkv[:, :, 0],
972
+ qkv[:, :, 1],
973
+ qkv[:, :, 2],
974
+ o,
975
+ lse,
976
+ dqkv[:, :, 0],
977
+ dqkv[:, :, 1],
978
+ dqkv[:, :, 2],
979
+ bias=bias,
980
+ causal=ctx.causal,
981
+ softmax_scale=ctx.softmax_scale,
982
+ )
983
+ return (dqkv, None, None, None)
984
+
985
+
986
+ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
987
+
988
+
989
+ class FlashAttnKVPackedFunc(torch.autograd.Function):
990
+ @staticmethod
991
+ def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
992
+ """
993
+ q: (batch, seqlen_q, nheads, headdim)
994
+ kv: (batch, seqlen_k, 2, nheads, headdim)
995
+ bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
996
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
997
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
998
+ """
999
+ (q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
1000
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(
1001
+ q,
1002
+ kv[:, :, 0],
1003
+ kv[:, :, 1],
1004
+ bias=bias,
1005
+ causal=causal,
1006
+ softmax_scale=softmax_scale,
1007
+ )
1008
+ ctx.save_for_backward(q, kv, o, lse, bias)
1009
+ ctx.causal = causal
1010
+ return o
1011
+
1012
+ @staticmethod
1013
+ def backward(ctx, do):
1014
+ (q, kv, o, lse, bias) = ctx.saved_tensors
1015
+ if len(ctx.needs_input_grad) >= 3:
1016
+ assert not ctx.needs_input_grad[
1017
+ 2
1018
+ ], "FlashAttention does not support bias gradient yet"
1019
+ with torch.inference_mode():
1020
+ dq = torch.empty_like(q)
1021
+ dkv = torch.empty_like(kv)
1022
+ _flash_attn_backward(
1023
+ do,
1024
+ q,
1025
+ kv[:, :, 0],
1026
+ kv[:, :, 1],
1027
+ o,
1028
+ lse,
1029
+ dq,
1030
+ dkv[:, :, 0],
1031
+ dkv[:, :, 1],
1032
+ bias=bias,
1033
+ causal=ctx.causal,
1034
+ softmax_scale=ctx.softmax_scale,
1035
+ )
1036
+ return (dq, dkv, None, None, None)
1037
+
1038
+
1039
+ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
1040
+
1041
+
1042
+ class FlashAttnFunc(torch.autograd.Function):
1043
+ @staticmethod
1044
+ def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
1045
+ """
1046
+ q: (batch_size, seqlen_q, nheads, headdim)
1047
+ k, v: (batch_size, seqlen_k, nheads, headdim)
1048
+ bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
1049
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
1050
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
1051
+ """
1052
+ (q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
1053
+ (o, lse, ctx.softmax_scale) = _flash_attn_forward(
1054
+ q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
1055
+ )
1056
+ ctx.save_for_backward(q, k, v, o, lse, bias)
1057
+ ctx.causal = causal
1058
+ return o
1059
+
1060
+ @staticmethod
1061
+ def backward(ctx, do):
1062
+ (q, k, v, o, lse, bias) = ctx.saved_tensors
1063
+ assert not ctx.needs_input_grad[
1064
+ 3
1065
+ ], "FlashAttention does not support bias gradient yet"
1066
+ with torch.inference_mode():
1067
+ dq = torch.empty_like(q)
1068
+ dk = torch.empty_like(k)
1069
+ dv = torch.empty_like(v)
1070
+ _flash_attn_backward(
1071
+ do,
1072
+ q,
1073
+ k,
1074
+ v,
1075
+ o,
1076
+ lse,
1077
+ dq,
1078
+ dk,
1079
+ dv,
1080
+ bias=bias,
1081
+ causal=ctx.causal,
1082
+ softmax_scale=ctx.softmax_scale,
1083
+ )
1084
+ return (dq, dk, dv, None, None, None)
1085
+
1086
+
1087
+ flash_attn_func = FlashAttnFunc.apply
lisa_on_cuda/llava/model/language_model/mpt/hf_prefixlm_converter.py ADDED
@@ -0,0 +1,750 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Converts Huggingface Causal LM to Prefix LM.
2
+
3
+ Conversion does lightweight surgery on a HuggingFace
4
+ Causal LM to convert it to a Prefix LM.
5
+
6
+ Prefix LMs accepts a `bidirectional_mask` input in `forward`
7
+ and treat the input prompt as the prefix in `generate`.
8
+ """
9
+ import math
10
+ import warnings
11
+ from types import MethodType
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+
14
+ import torch
15
+ from transformers.models.bloom.modeling_bloom import (
16
+ BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel,
17
+ CausalLMOutputWithCrossAttentions, CrossEntropyLoss)
18
+ from transformers.models.bloom.modeling_bloom import \
19
+ _expand_mask as _expand_mask_bloom
20
+ from transformers.models.bloom.modeling_bloom import \
21
+ _make_causal_mask as _make_causal_mask_bloom
22
+ from transformers.models.bloom.modeling_bloom import logging
23
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
24
+ from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
25
+ from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
26
+ from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
27
+ from transformers.models.opt.modeling_opt import OPTForCausalLM
28
+ from transformers.models.opt.modeling_opt import \
29
+ _expand_mask as _expand_mask_opt
30
+ from transformers.models.opt.modeling_opt import \
31
+ _make_causal_mask as _make_causal_mask_opt
32
+
33
+ logger = logging.get_logger(__name__)
34
+ _SUPPORTED_GPT_MODELS = (
35
+ GPT2LMHeadModel,
36
+ GPTJForCausalLM,
37
+ GPTNeoForCausalLM,
38
+ GPTNeoXForCausalLM,
39
+ )
40
+ CAUSAL_GPT_TYPES = Union[
41
+ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
42
+ ]
43
+
44
+
45
+ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
46
+ """Converts a GPT-style Causal LM to a Prefix LM.
47
+
48
+ Supported HuggingFace model classes:
49
+ - `GPT2LMHeadModel`
50
+ - `GPTNeoForCausalLM`
51
+ - `GPTNeoXForCausalLM`
52
+ - `GPTJForCausalLM`
53
+
54
+ See `convert_hf_causal_lm_to_prefix_lm` for more details.
55
+ """
56
+ if hasattr(model, "_prefix_lm_converted"):
57
+ return model
58
+ assert isinstance(model, _SUPPORTED_GPT_MODELS)
59
+ assert (
60
+ model.config.add_cross_attention == False
61
+ ), "Only supports GPT-style decoder-only models"
62
+
63
+ def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
64
+ """Helper that gets a list of the model's attention modules.
65
+
66
+ Each module has a `bias` buffer used for causal masking. The Prefix LM
67
+ conversion adds logic to dynamically manipulate these biases to support
68
+ Prefix LM attention masking.
69
+ """
70
+ attn_modules = []
71
+ if isinstance(model, GPTNeoXForCausalLM):
72
+ blocks = model.gpt_neox.layers
73
+ else:
74
+ blocks = model.transformer.h
75
+ for block in blocks:
76
+ if isinstance(model, GPTNeoForCausalLM):
77
+ if block.attn.attention_type != "global":
78
+ continue
79
+ attn_module = block.attn.attention
80
+ elif isinstance(model, GPTNeoXForCausalLM):
81
+ attn_module = block.attention
82
+ else:
83
+ attn_module = block.attn
84
+ attn_modules.append(attn_module)
85
+ return attn_modules
86
+
87
+ setattr(model, "_original_forward", getattr(model, "forward"))
88
+ setattr(model, "_original_generate", getattr(model, "generate"))
89
+
90
+ def forward(
91
+ self: CAUSAL_GPT_TYPES,
92
+ input_ids: Optional[torch.LongTensor] = None,
93
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
94
+ attention_mask: Optional[torch.FloatTensor] = None,
95
+ bidirectional_mask: Optional[torch.Tensor] = None,
96
+ token_type_ids: Optional[torch.LongTensor] = None,
97
+ position_ids: Optional[torch.LongTensor] = None,
98
+ head_mask: Optional[torch.FloatTensor] = None,
99
+ inputs_embeds: Optional[torch.FloatTensor] = None,
100
+ labels: Optional[torch.LongTensor] = None,
101
+ use_cache: Optional[bool] = None,
102
+ output_attentions: Optional[bool] = None,
103
+ output_hidden_states: Optional[bool] = None,
104
+ return_dict: Optional[bool] = None,
105
+ ):
106
+ """Wraps original forward to enable PrefixLM attention."""
107
+
108
+ def call_og_forward():
109
+ if isinstance(self, GPTNeoXForCausalLM):
110
+ return self._original_forward(
111
+ input_ids=input_ids,
112
+ past_key_values=past_key_values,
113
+ attention_mask=attention_mask,
114
+ head_mask=head_mask,
115
+ inputs_embeds=inputs_embeds,
116
+ labels=labels,
117
+ use_cache=use_cache,
118
+ output_attentions=output_attentions,
119
+ output_hidden_states=output_hidden_states,
120
+ return_dict=return_dict,
121
+ )
122
+ else:
123
+ return self._original_forward(
124
+ input_ids=input_ids,
125
+ past_key_values=past_key_values,
126
+ attention_mask=attention_mask,
127
+ token_type_ids=token_type_ids,
128
+ position_ids=position_ids,
129
+ head_mask=head_mask,
130
+ inputs_embeds=inputs_embeds,
131
+ labels=labels,
132
+ use_cache=use_cache,
133
+ output_attentions=output_attentions,
134
+ output_hidden_states=output_hidden_states,
135
+ return_dict=return_dict,
136
+ )
137
+
138
+ if bidirectional_mask is None:
139
+ return call_og_forward()
140
+ assert isinstance(bidirectional_mask, torch.Tensor)
141
+ attn_modules = _get_attn_modules(model)
142
+ (b, s) = bidirectional_mask.shape
143
+ max_length = attn_modules[0].bias.shape[-1]
144
+ if s > max_length:
145
+ raise ValueError(
146
+ f"bidirectional_mask sequence length (={s}) exceeds the "
147
+ + f"max length allowed by the model ({max_length})."
148
+ )
149
+ assert s <= max_length
150
+ if s < max_length:
151
+ pad = torch.zeros(
152
+ (int(b), int(max_length - s)),
153
+ dtype=bidirectional_mask.dtype,
154
+ device=bidirectional_mask.device,
155
+ )
156
+ bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
157
+ bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
158
+ for attn_module in attn_modules:
159
+ attn_module.bias.data = torch.logical_or(
160
+ attn_module.bias.data, bidirectional
161
+ )
162
+ output = call_og_forward()
163
+ for attn_module in attn_modules:
164
+ attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
165
+ return output
166
+
167
+ def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
168
+ """Wraps original generate to enable PrefixLM attention."""
169
+ attn_modules = _get_attn_modules(model)
170
+ for attn_module in attn_modules:
171
+ attn_module.bias.data[:] = 1
172
+ output = self._original_generate(*args, **kwargs)
173
+ for attn_module in attn_modules:
174
+ attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
175
+ return output
176
+
177
+ setattr(model, "forward", MethodType(forward, model))
178
+ setattr(model, "generate", MethodType(generate, model))
179
+ setattr(model, "_prefix_lm_converted", True)
180
+ return model
181
+
182
+
183
+ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
184
+ """Converts a BLOOM Causal LM to a Prefix LM.
185
+
186
+ Supported HuggingFace model classes:
187
+ - `BloomForCausalLM`
188
+
189
+ See `convert_hf_causal_lm_to_prefix_lm` for more details.
190
+ """
191
+ if hasattr(model, "_prefix_lm_converted"):
192
+ return model
193
+ assert isinstance(model, BloomForCausalLM)
194
+ assert (
195
+ model.config.add_cross_attention == False
196
+ ), "Only supports BLOOM decoder-only models"
197
+
198
+ def _prepare_attn_mask(
199
+ self: BloomModel,
200
+ attention_mask: torch.Tensor,
201
+ bidirectional_mask: Optional[torch.Tensor],
202
+ input_shape: Tuple[int, int],
203
+ past_key_values_length: int,
204
+ ) -> torch.BoolTensor:
205
+ combined_attention_mask = None
206
+ device = attention_mask.device
207
+ (_, src_length) = input_shape
208
+ if src_length > 1:
209
+ combined_attention_mask = _make_causal_mask_bloom(
210
+ input_shape,
211
+ device=device,
212
+ past_key_values_length=past_key_values_length,
213
+ )
214
+ if bidirectional_mask is not None:
215
+ assert attention_mask.shape == bidirectional_mask.shape
216
+ expanded_bidirectional_mask = _expand_mask_bloom(
217
+ bidirectional_mask, tgt_length=src_length
218
+ )
219
+ combined_attention_mask = torch.logical_and(
220
+ combined_attention_mask, expanded_bidirectional_mask
221
+ )
222
+ expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
223
+ combined_attention_mask = (
224
+ expanded_attn_mask
225
+ if combined_attention_mask is None
226
+ else expanded_attn_mask | combined_attention_mask
227
+ )
228
+ return combined_attention_mask
229
+
230
+ def _build_alibi_tensor(
231
+ self: BloomModel,
232
+ batch_size: int,
233
+ query_length: int,
234
+ key_length: int,
235
+ dtype: torch.dtype,
236
+ device: torch.device,
237
+ ) -> torch.Tensor:
238
+ num_heads = self.config.n_head
239
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
240
+ base = torch.tensor(
241
+ 2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))),
242
+ device=device,
243
+ dtype=torch.float32,
244
+ )
245
+ powers = torch.arange(
246
+ 1, 1 + closest_power_of_2, device=device, dtype=torch.int32
247
+ )
248
+ slopes = torch.pow(base, powers)
249
+ if closest_power_of_2 != num_heads:
250
+ extra_base = torch.tensor(
251
+ 2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))),
252
+ device=device,
253
+ dtype=torch.float32,
254
+ )
255
+ num_remaining_heads = min(
256
+ closest_power_of_2, num_heads - closest_power_of_2
257
+ )
258
+ extra_powers = torch.arange(
259
+ 1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32
260
+ )
261
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
262
+ qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
263
+ ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
264
+ diffs = qa - ka + key_length - query_length
265
+ diffs = -diffs.abs()
266
+ alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(
267
+ 1, 1, query_length, key_length
268
+ )
269
+ alibi = alibi.expand(batch_size, -1, -1, -1).reshape(
270
+ -1, query_length, key_length
271
+ )
272
+ return alibi.to(dtype)
273
+
274
+ KeyValueT = Tuple[torch.Tensor, torch.Tensor]
275
+
276
+ def forward(
277
+ self: BloomModel,
278
+ input_ids: Optional[torch.LongTensor] = None,
279
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
280
+ attention_mask: Optional[torch.Tensor] = None,
281
+ bidirectional_mask: Optional[torch.Tensor] = None,
282
+ head_mask: Optional[torch.LongTensor] = None,
283
+ inputs_embeds: Optional[torch.LongTensor] = None,
284
+ use_cache: Optional[bool] = None,
285
+ output_attentions: Optional[bool] = None,
286
+ output_hidden_states: Optional[bool] = None,
287
+ return_dict: Optional[bool] = None,
288
+ **deprecated_arguments,
289
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
290
+ if deprecated_arguments.pop("position_ids", False) is not False:
291
+ warnings.warn(
292
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. "
293
+ + "You can safely ignore passing `position_ids`.",
294
+ FutureWarning,
295
+ )
296
+ if len(deprecated_arguments) > 0:
297
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
298
+ output_attentions = (
299
+ output_attentions
300
+ if output_attentions is not None
301
+ else self.config.output_attentions
302
+ )
303
+ output_hidden_states = (
304
+ output_hidden_states
305
+ if output_hidden_states is not None
306
+ else self.config.output_hidden_states
307
+ )
308
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
309
+ return_dict = (
310
+ return_dict if return_dict is not None else self.config.use_return_dict
311
+ )
312
+ if input_ids is not None and inputs_embeds is not None:
313
+ raise ValueError(
314
+ "You cannot specify both input_ids and inputs_embeds at the same time"
315
+ )
316
+ elif input_ids is not None:
317
+ (batch_size, seq_length) = input_ids.shape
318
+ elif inputs_embeds is not None:
319
+ (batch_size, seq_length, _) = inputs_embeds.shape
320
+ else:
321
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
322
+ if past_key_values is None:
323
+ past_key_values = tuple([None] * len(self.h))
324
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
325
+ if inputs_embeds is None:
326
+ inputs_embeds = self.word_embeddings(input_ids)
327
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
328
+ presents = () if use_cache else None
329
+ all_self_attentions = () if output_attentions else None
330
+ all_hidden_states = () if output_hidden_states else None
331
+ seq_length_with_past = seq_length
332
+ past_key_values_length = 0
333
+ if past_key_values[0] is not None:
334
+ tmp = past_key_values[0][0]
335
+ past_key_values_length = tmp.shape[2]
336
+ seq_length_with_past = seq_length_with_past + past_key_values_length
337
+ if attention_mask is None:
338
+ attention_mask = torch.ones(
339
+ (batch_size, seq_length_with_past), device=hidden_states.device
340
+ )
341
+ else:
342
+ attention_mask = attention_mask.to(hidden_states.device)
343
+ alibi = self._build_alibi_tensor(
344
+ batch_size=batch_size,
345
+ query_length=seq_length,
346
+ key_length=seq_length_with_past,
347
+ dtype=hidden_states.dtype,
348
+ device=hidden_states.device,
349
+ )
350
+ causal_mask = self._prepare_attn_mask(
351
+ attention_mask,
352
+ bidirectional_mask,
353
+ input_shape=(batch_size, seq_length),
354
+ past_key_values_length=past_key_values_length,
355
+ )
356
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
357
+ if output_hidden_states:
358
+ hst = (hidden_states,)
359
+ all_hidden_states = all_hidden_states + hst
360
+ if self.gradient_checkpointing and self.training:
361
+ if use_cache:
362
+ logger.warning(
363
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
364
+ )
365
+ use_cache = False
366
+
367
+ def create_custom_forward(module):
368
+ def custom_forward(*inputs):
369
+ return module(
370
+ *inputs,
371
+ use_cache=use_cache,
372
+ output_attentions=output_attentions,
373
+ )
374
+
375
+ return custom_forward
376
+
377
+ outputs = torch.utils.checkpoint.checkpoint(
378
+ create_custom_forward(block),
379
+ hidden_states,
380
+ alibi,
381
+ causal_mask,
382
+ head_mask[i],
383
+ )
384
+ else:
385
+ outputs = block(
386
+ hidden_states,
387
+ layer_past=layer_past,
388
+ attention_mask=causal_mask,
389
+ head_mask=head_mask[i],
390
+ use_cache=use_cache,
391
+ output_attentions=output_attentions,
392
+ alibi=alibi,
393
+ )
394
+ hidden_states = outputs[0]
395
+ if use_cache is True:
396
+ presents = presents + (outputs[1],)
397
+ if output_attentions:
398
+ oa = (outputs[2 if use_cache else 1],)
399
+ all_self_attentions = all_self_attentions + oa
400
+ hidden_states = self.ln_f(hidden_states)
401
+ if output_hidden_states:
402
+ hst = (hidden_states,)
403
+ all_hidden_states = all_hidden_states + hst
404
+ if not return_dict:
405
+ return tuple(
406
+ (
407
+ v
408
+ for v in [
409
+ hidden_states,
410
+ presents,
411
+ all_hidden_states,
412
+ all_self_attentions,
413
+ ]
414
+ if v is not None
415
+ )
416
+ )
417
+ return BaseModelOutputWithPastAndCrossAttentions(
418
+ last_hidden_state=hidden_states,
419
+ past_key_values=presents,
420
+ hidden_states=all_hidden_states,
421
+ attentions=all_self_attentions,
422
+ )
423
+
424
+ setattr(
425
+ model.transformer,
426
+ "_prepare_attn_mask",
427
+ MethodType(_prepare_attn_mask, model.transformer),
428
+ )
429
+ setattr(
430
+ model.transformer,
431
+ "_build_alibi_tensor",
432
+ MethodType(_build_alibi_tensor, model.transformer),
433
+ )
434
+ setattr(model.transformer, "forward", MethodType(forward, model.transformer))
435
+ KeyValueT = Tuple[torch.Tensor, torch.Tensor]
436
+
437
+ def forward(
438
+ self: BloomForCausalLM,
439
+ input_ids: Optional[torch.LongTensor] = None,
440
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
441
+ attention_mask: Optional[torch.Tensor] = None,
442
+ bidirectional_mask: Optional[torch.Tensor] = None,
443
+ head_mask: Optional[torch.Tensor] = None,
444
+ inputs_embeds: Optional[torch.Tensor] = None,
445
+ labels: Optional[torch.Tensor] = None,
446
+ use_cache: Optional[bool] = None,
447
+ output_attentions: Optional[bool] = None,
448
+ output_hidden_states: Optional[bool] = None,
449
+ return_dict: Optional[bool] = None,
450
+ **deprecated_arguments,
451
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
452
+ """Replacement forward method for BloomCausalLM."""
453
+ if deprecated_arguments.pop("position_ids", False) is not False:
454
+ warnings.warn(
455
+ "`position_ids` have no functionality in BLOOM and will be removed "
456
+ + "in v5.0.0. You can safely ignore passing `position_ids`.",
457
+ FutureWarning,
458
+ )
459
+ if len(deprecated_arguments) > 0:
460
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
461
+ return_dict = (
462
+ return_dict if return_dict is not None else self.config.use_return_dict
463
+ )
464
+ transformer_outputs = self.transformer(
465
+ input_ids,
466
+ past_key_values=past_key_values,
467
+ attention_mask=attention_mask,
468
+ bidirectional_mask=bidirectional_mask,
469
+ head_mask=head_mask,
470
+ inputs_embeds=inputs_embeds,
471
+ use_cache=use_cache,
472
+ output_attentions=output_attentions,
473
+ output_hidden_states=output_hidden_states,
474
+ return_dict=return_dict,
475
+ )
476
+ hidden_states = transformer_outputs[0]
477
+ lm_logits = self.lm_head(hidden_states)
478
+ loss = None
479
+ if labels is not None:
480
+ shift_logits = lm_logits[..., :-1, :].contiguous()
481
+ shift_labels = labels[..., 1:].contiguous()
482
+ (batch_size, seq_length, vocab_size) = shift_logits.shape
483
+ loss_fct = CrossEntropyLoss()
484
+ loss = loss_fct(
485
+ shift_logits.view(batch_size * seq_length, vocab_size),
486
+ shift_labels.view(batch_size * seq_length),
487
+ )
488
+ if not return_dict:
489
+ output = (lm_logits,) + transformer_outputs[1:]
490
+ return (loss,) + output if loss is not None else output
491
+ return CausalLMOutputWithCrossAttentions(
492
+ loss=loss,
493
+ logits=lm_logits,
494
+ past_key_values=transformer_outputs.past_key_values,
495
+ hidden_states=transformer_outputs.hidden_states,
496
+ attentions=transformer_outputs.attentions,
497
+ )
498
+
499
+ def prepare_inputs_for_generation(
500
+ self: BloomForCausalLM,
501
+ input_ids: torch.LongTensor,
502
+ past: Optional[torch.Tensor] = None,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ **kwargs,
505
+ ) -> dict:
506
+ if past:
507
+ input_ids = input_ids[:, -1].unsqueeze(-1)
508
+ bidirectional_mask = None
509
+ if past[0][0].shape[0] == input_ids.shape[0]:
510
+ past = self._convert_to_bloom_cache(past)
511
+ else:
512
+ bidirectional_mask = torch.ones_like(input_ids)
513
+ return {
514
+ "input_ids": input_ids,
515
+ "past_key_values": past,
516
+ "use_cache": True,
517
+ "attention_mask": attention_mask,
518
+ "bidirectional_mask": bidirectional_mask,
519
+ }
520
+
521
+ setattr(model, "forward", MethodType(forward, model))
522
+ setattr(
523
+ model,
524
+ "prepare_inputs_for_generation",
525
+ MethodType(prepare_inputs_for_generation, model),
526
+ )
527
+ setattr(model, "_prefix_lm_converted", True)
528
+ return model
529
+
530
+
531
+ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
532
+ """Converts an OPT Causal LM to a Prefix LM.
533
+
534
+ Supported HuggingFace model classes:
535
+ - `OPTForCausalLM`
536
+
537
+ See `convert_hf_causal_lm_to_prefix_lm` for more details.
538
+ """
539
+ if hasattr(model, "_prefix_lm_converted"):
540
+ return model
541
+ assert isinstance(model, OPTForCausalLM)
542
+ assert (
543
+ model.config.add_cross_attention == False
544
+ ), "Only supports OPT decoder-only models"
545
+ setattr(model, "_original_forward", getattr(model, "forward"))
546
+ setattr(model, "_original_generate", getattr(model, "generate"))
547
+ model.model.decoder.bidirectional_mask = None
548
+
549
+ def _prepare_decoder_attention_mask(
550
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
551
+ ):
552
+ combined_attention_mask = None
553
+ if input_shape[-1] > 1:
554
+ if self.bidirectional_mask == "g":
555
+ (bsz, src_length) = input_shape
556
+ combined_attention_mask = torch.zeros(
557
+ (bsz, 1, src_length, src_length + past_key_values_length),
558
+ dtype=inputs_embeds.dtype,
559
+ device=inputs_embeds.device,
560
+ )
561
+ else:
562
+ combined_attention_mask = _make_causal_mask_opt(
563
+ input_shape,
564
+ inputs_embeds.dtype,
565
+ past_key_values_length=past_key_values_length,
566
+ ).to(inputs_embeds.device)
567
+ if self.bidirectional_mask is not None:
568
+ assert attention_mask.shape == self.bidirectional_mask.shape
569
+ expanded_bidirectional_mask = _expand_mask_opt(
570
+ self.bidirectional_mask,
571
+ inputs_embeds.dtype,
572
+ tgt_len=input_shape[-1],
573
+ ).to(inputs_embeds.device)
574
+ combined_attention_mask = torch.maximum(
575
+ expanded_bidirectional_mask, combined_attention_mask
576
+ )
577
+ if attention_mask is not None:
578
+ expanded_attn_mask = _expand_mask_opt(
579
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
580
+ ).to(inputs_embeds.device)
581
+ combined_attention_mask = (
582
+ expanded_attn_mask
583
+ if combined_attention_mask is None
584
+ else expanded_attn_mask + combined_attention_mask
585
+ )
586
+ return combined_attention_mask
587
+
588
+ setattr(
589
+ model.model.decoder,
590
+ "_prepare_decoder_attention_mask",
591
+ MethodType(_prepare_decoder_attention_mask, model.model.decoder),
592
+ )
593
+
594
+ def forward(
595
+ self: OPTForCausalLM,
596
+ input_ids: Optional[torch.LongTensor] = None,
597
+ attention_mask: Optional[torch.Tensor] = None,
598
+ bidirectional_mask: Optional[torch.ByteTensor] = None,
599
+ head_mask: Optional[torch.Tensor] = None,
600
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
601
+ inputs_embeds: Optional[torch.FloatTensor] = None,
602
+ labels: Optional[torch.LongTensor] = None,
603
+ use_cache: Optional[bool] = None,
604
+ output_attentions: Optional[bool] = None,
605
+ output_hidden_states: Optional[bool] = None,
606
+ return_dict: Optional[bool] = None,
607
+ ):
608
+ def call_og_forward():
609
+ return self._original_forward(
610
+ input_ids=input_ids,
611
+ attention_mask=attention_mask,
612
+ head_mask=head_mask,
613
+ past_key_values=past_key_values,
614
+ inputs_embeds=inputs_embeds,
615
+ labels=labels,
616
+ use_cache=use_cache,
617
+ output_attentions=output_attentions,
618
+ output_hidden_states=output_hidden_states,
619
+ return_dict=return_dict,
620
+ )
621
+
622
+ if bidirectional_mask is None:
623
+ return call_og_forward()
624
+ self.model.decoder.bidirectional_mask = bidirectional_mask
625
+ try:
626
+ outputs = call_og_forward()
627
+ except:
628
+ self.model.decoder.bidirectional_mask = None
629
+ raise
630
+ self.model.decoder.bidirectional_mask = None
631
+ return outputs
632
+
633
+ def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
634
+ """Wraps original generate to enable PrefixLM-style attention."""
635
+ self.model.decoder.bidirectional_mask = "g"
636
+ try:
637
+ output = self._original_generate(*args, **kwargs)
638
+ except:
639
+ self.model.decoder.bidirectional_mask = None
640
+ raise
641
+ self.model.decoder.bidirectional_mask = None
642
+ return output
643
+
644
+ setattr(model, "forward", MethodType(forward, model))
645
+ setattr(model, "generate", MethodType(generate, model))
646
+ setattr(model, "_prefix_lm_converted", True)
647
+ return model
648
+
649
+
650
+ _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
651
+ CAUSAL_LM_TYPES = Union[
652
+ GPT2LMHeadModel,
653
+ GPTJForCausalLM,
654
+ GPTNeoForCausalLM,
655
+ GPTNeoXForCausalLM,
656
+ BloomForCausalLM,
657
+ OPTForCausalLM,
658
+ ]
659
+
660
+
661
+ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
662
+ """Converts a HuggingFace Causal LM to a Prefix LM.
663
+
664
+ Supported HuggingFace model classes:
665
+ - `GPT2LMHeadModel`
666
+ - `GPTNeoForCausalLM`
667
+ - `GPTNeoXForCausalLM`
668
+ - `GPTJForCausalLM`
669
+ - `BloomForCausalLM`
670
+ - `OPTForCausalLM`
671
+
672
+ Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
673
+ `generate` method and/or select underlying methods depending on the model class.
674
+
675
+ These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
676
+
677
+ Notes on training:
678
+ To actually train the converted model as a Prefix LM, training batches will need to indicate
679
+ the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
680
+
681
+ **This is not a standard input and requires custom layers either within or after your dataloader.**
682
+
683
+ In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
684
+ such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
685
+ That is, the prefix portion of the sequence should not generate any loss. Loss should only be
686
+ generated by the target portion of the sequence.
687
+
688
+ Notes on `GPTNeoForCausalLM`:
689
+ To simplify the implementation, "global" and "local" attention layers are handled differently.
690
+ For "global" layers, we handle conversion as described above. For "local" layers, which use a
691
+ causal attention mask within a restricted local window, we do not alter the masking.
692
+
693
+ Notes on `forward` method conversion:
694
+ After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
695
+ which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
696
+ belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
697
+ 0 indicates token positions belonging to the target.
698
+
699
+ The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
700
+ causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
701
+ the causal masks before returning the result.
702
+
703
+ Notes on `generate` method conversion:
704
+ After conversion, the `generate` method will have the same signature but will internally
705
+ convert all causal masks to be purely bidirectional, call the original `generate` method, and
706
+ (where appropriate) reset the causal masks before returning the result.
707
+
708
+ This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
709
+ "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
710
+ each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
711
+ another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
712
+ previously-generated tokens (also as expected in a Prefix LM).
713
+
714
+ To preserve the API, the original methods are renamed to `_original_forward` and
715
+ `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
716
+ them, respectively. Although implementation details vary by model class.
717
+ """
718
+ if isinstance(model, _SUPPORTED_GPT_MODELS):
719
+ return _convert_gpt_causal_lm_to_prefix_lm(model)
720
+ elif isinstance(model, BloomForCausalLM):
721
+ return _convert_bloom_causal_lm_to_prefix_lm(model)
722
+ elif isinstance(model, OPTForCausalLM):
723
+ return _convert_opt_causal_lm_to_prefix_lm(model)
724
+ else:
725
+ raise TypeError(
726
+ f"Cannot convert model to Prefix LM. "
727
+ + f"Model does not belong to set of supported HF models:"
728
+ + f"\n{_SUPPORTED_HF_MODELS}"
729
+ )
730
+
731
+
732
+ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
733
+ """Attempts to add bidirectional_mask to batch if missing.
734
+
735
+ Raises:
736
+ KeyError if bidirectional_mask is missing and can't be inferred
737
+ """
738
+ if "bidirectional_mask" not in batch:
739
+ if batch.get("mode", None) == "icl_task":
740
+ batch["bidirectional_mask"] = batch["attention_mask"].clone()
741
+ for i, continuation_indices in enumerate(batch["continuation_indices"]):
742
+ batch["bidirectional_mask"][i, continuation_indices] = 0
743
+ elif "labels" in batch and "attention_mask" in batch:
744
+ batch["bidirectional_mask"] = torch.logical_and(
745
+ torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)
746
+ ).type_as(batch["attention_mask"])
747
+ else:
748
+ raise KeyError(
749
+ "No bidirectional_mask in batch and not sure how to construct one."
750
+ )
lisa_on_cuda/llava/model/language_model/mpt/meta_init_context.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ @contextmanager
8
+ def init_empty_weights(include_buffers: bool = False):
9
+ """Meta initialization context manager.
10
+
11
+ A context manager under which models are initialized with all parameters
12
+ on the meta device, therefore creating an empty model. Useful when just
13
+ initializing the model would blow the available RAM.
14
+
15
+ Args:
16
+ include_buffers (`bool`, *optional*, defaults to `False`): Whether or
17
+ not to also put all buffers on the meta device while initializing.
18
+
19
+ Example:
20
+ ```python
21
+ import torch.nn as nn
22
+
23
+ # Initialize a model with 100 billions parameters in no time and without using any RAM.
24
+ with init_empty_weights():
25
+ tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
26
+ ```
27
+
28
+ <Tip warning={true}>
29
+
30
+ Any model created under this context manager has no weights. As such you can't do something like
31
+ `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
32
+
33
+ </Tip>
34
+ """
35
+ with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
36
+ yield f
37
+
38
+
39
+ @contextmanager
40
+ def init_on_device(device: torch.device, include_buffers: bool = False):
41
+ """Device initialization context manager.
42
+
43
+ A context manager under which models are initialized with all parameters
44
+ on the specified device.
45
+
46
+ Args:
47
+ device (`torch.device`): Device to initialize all parameters on.
48
+ include_buffers (`bool`, *optional*, defaults to `False`): Whether or
49
+ not to also put all buffers on the meta device while initializing.
50
+
51
+ Example:
52
+ ```python
53
+ import torch.nn as nn
54
+
55
+ with init_on_device(device=torch.device("cuda")):
56
+ tst = nn.Liner(100, 100) # on `cuda` device
57
+ ```
58
+ """
59
+ old_register_parameter = nn.Module.register_parameter
60
+ if include_buffers:
61
+ old_register_buffer = nn.Module.register_buffer
62
+
63
+ def register_empty_parameter(module, name, param):
64
+ old_register_parameter(module, name, param)
65
+ if param is not None:
66
+ param_cls = type(module._parameters[name])
67
+ kwargs = module._parameters[name].__dict__
68
+ module._parameters[name] = param_cls(
69
+ module._parameters[name].to(device), **kwargs
70
+ )
71
+
72
+ def register_empty_buffer(module, name, buffer):
73
+ old_register_buffer(module, name, buffer)
74
+ if buffer is not None:
75
+ module._buffers[name] = module._buffers[name].to(device)
76
+
77
+ if include_buffers:
78
+ tensor_constructors_to_patch = {
79
+ torch_function_name: getattr(torch, torch_function_name)
80
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
81
+ }
82
+ else:
83
+ tensor_constructors_to_patch = {}
84
+
85
+ def patch_tensor_constructor(fn):
86
+ def wrapper(*args, **kwargs):
87
+ kwargs["device"] = device
88
+ return fn(*args, **kwargs)
89
+
90
+ return wrapper
91
+
92
+ try:
93
+ nn.Module.register_parameter = register_empty_parameter
94
+ if include_buffers:
95
+ nn.Module.register_buffer = register_empty_buffer
96
+ for torch_function_name in tensor_constructors_to_patch.keys():
97
+ setattr(
98
+ torch,
99
+ torch_function_name,
100
+ patch_tensor_constructor(getattr(torch, torch_function_name)),
101
+ )
102
+ yield
103
+ finally:
104
+ nn.Module.register_parameter = old_register_parameter
105
+ if include_buffers:
106
+ nn.Module.register_buffer = old_register_buffer
107
+ for (
108
+ torch_function_name,
109
+ old_torch_function,
110
+ ) in tensor_constructors_to_patch.items():
111
+ setattr(torch, torch_function_name, old_torch_function)
lisa_on_cuda/llava/model/language_model/mpt/modeling_mpt.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A simple, flexible implementation of a GPT model.
2
+
3
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
4
+ """
5
+ import math
6
+ import warnings
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from transformers import (PreTrainedModel, PreTrainedTokenizer,
13
+ PreTrainedTokenizerFast)
14
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast)
16
+
17
+ from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
18
+ from .attention import attn_bias_shape, build_attn_bias
19
+ from .blocks import MPTBlock
20
+ from .configuration_mpt import MPTConfig
21
+ from .custom_embedding import SharedEmbedding
22
+ from .hf_prefixlm_converter import (add_bidirectional_mask_if_missing,
23
+ convert_hf_causal_lm_to_prefix_lm)
24
+ from .meta_init_context import init_empty_weights
25
+ from .norm import NORM_CLASS_REGISTRY
26
+ from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
27
+
28
+ try:
29
+ from .flash_attn_triton import flash_attn_func
30
+ except:
31
+ pass
32
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
33
+
34
+
35
+ class MPTPreTrainedModel(PreTrainedModel):
36
+ config_class = MPTConfig
37
+ base_model_prefix = "model"
38
+ _no_split_modules = ["MPTBlock"]
39
+
40
+
41
+ class MPTModel(MPTPreTrainedModel):
42
+ def __init__(self, config: MPTConfig):
43
+ config._validate_config()
44
+ super().__init__(config)
45
+ self.attn_impl = config.attn_config["attn_impl"]
46
+ self.prefix_lm = config.attn_config["prefix_lm"]
47
+ self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
48
+ self.alibi = config.attn_config["alibi"]
49
+ self.alibi_bias_max = config.attn_config["alibi_bias_max"]
50
+ if config.init_device == "mixed":
51
+ if dist.get_local_rank() == 0:
52
+ config.init_device = "cpu"
53
+ else:
54
+ config.init_device = "meta"
55
+ if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
56
+ norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
57
+ raise NotImplementedError(
58
+ f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
59
+ )
60
+ norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
61
+ self.embedding_fraction = config.embedding_fraction
62
+ self.wte = SharedEmbedding(
63
+ config.vocab_size, config.d_model, device=config.init_device
64
+ )
65
+ if not self.alibi:
66
+ self.wpe = torch.nn.Embedding(
67
+ config.max_seq_len, config.d_model, device=config.init_device
68
+ )
69
+ self.emb_drop = nn.Dropout(config.emb_pdrop)
70
+ self.blocks = nn.ModuleList(
71
+ [
72
+ MPTBlock(device=config.init_device, **config.to_dict())
73
+ for _ in range(config.n_layers)
74
+ ]
75
+ )
76
+ self.norm_f = norm_class(config.d_model, device=config.init_device)
77
+ if config.init_device != "meta":
78
+ print(
79
+ f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
80
+ )
81
+ self.apply(self.param_init_fn)
82
+ self.is_causal = not self.prefix_lm
83
+ self._attn_bias_initialized = False
84
+ self.attn_bias = None
85
+ self.attn_bias_shape = attn_bias_shape(
86
+ self.attn_impl,
87
+ config.n_heads,
88
+ config.max_seq_len,
89
+ self.alibi,
90
+ prefix_lm=self.prefix_lm,
91
+ causal=self.is_causal,
92
+ use_sequence_id=self.attn_uses_sequence_id,
93
+ )
94
+ if config.no_bias:
95
+ for module in self.modules():
96
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
97
+ if config.verbose:
98
+ warnings.warn(f"Removing bias ({module.bias}) from {module}.")
99
+ module.register_parameter("bias", None)
100
+ if config.verbose and config.verbose > 2:
101
+ print(self)
102
+ if "verbose" not in self.config.init_config:
103
+ self.config.init_config["verbose"] = self.config.verbose
104
+ if self.config.init_config["verbose"] > 1:
105
+ init_fn_name = self.config.init_config["name"]
106
+ warnings.warn(f"Using {init_fn_name} initialization.")
107
+ self.gradient_checkpointing = False
108
+
109
+ def get_input_embeddings(self):
110
+ return self.wte
111
+
112
+ def set_input_embeddings(self, value):
113
+ self.wte = value
114
+
115
+ @torch.no_grad()
116
+ def _attn_bias(
117
+ self,
118
+ device,
119
+ dtype,
120
+ attention_mask: Optional[torch.ByteTensor] = None,
121
+ prefix_mask: Optional[torch.ByteTensor] = None,
122
+ sequence_id: Optional[torch.LongTensor] = None,
123
+ ):
124
+ if not self._attn_bias_initialized:
125
+ if self.attn_bias_shape:
126
+ self.attn_bias = torch.zeros(
127
+ self.attn_bias_shape, device=device, dtype=dtype
128
+ )
129
+ self.attn_bias = build_attn_bias(
130
+ self.attn_impl,
131
+ self.attn_bias,
132
+ self.config.n_heads,
133
+ self.config.max_seq_len,
134
+ causal=self.is_causal,
135
+ alibi=self.alibi,
136
+ alibi_bias_max=self.alibi_bias_max,
137
+ )
138
+ self._attn_bias_initialized = True
139
+ if self.attn_impl == "flash":
140
+ return (self.attn_bias, attention_mask)
141
+ if self.attn_bias is not None:
142
+ self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
143
+ attn_bias = self.attn_bias
144
+ if self.prefix_lm:
145
+ assert isinstance(attn_bias, torch.Tensor)
146
+ assert isinstance(prefix_mask, torch.Tensor)
147
+ attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
148
+ if self.attn_uses_sequence_id and sequence_id is not None:
149
+ assert isinstance(attn_bias, torch.Tensor)
150
+ attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
151
+ if attention_mask is not None:
152
+ s_k = attention_mask.shape[-1]
153
+ if attn_bias is None:
154
+ attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
155
+ else:
156
+ _s_k = max(0, attn_bias.size(-1) - s_k)
157
+ attn_bias = attn_bias[:, :, :, _s_k:]
158
+ if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
159
+ raise ValueError(
160
+ f"attention_mask shape={attention_mask.shape} "
161
+ + f"and prefix_mask shape={prefix_mask.shape} are not equal."
162
+ )
163
+ min_val = torch.finfo(attn_bias.dtype).min
164
+ attn_bias = attn_bias.masked_fill(
165
+ ~attention_mask.view(-1, 1, 1, s_k), min_val
166
+ )
167
+ return (attn_bias, None)
168
+
169
+ def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
170
+ (s_k, s_q) = attn_bias.shape[-2:]
171
+ if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
172
+ raise ValueError(
173
+ "attn_bias does not match the expected shape. "
174
+ + f"The last two dimensions should both be {self.config.max_length} "
175
+ + f"but are {s_k} and {s_q}."
176
+ )
177
+ seq_len = prefix_mask.shape[-1]
178
+ if seq_len > self.config.max_seq_len:
179
+ raise ValueError(
180
+ f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
181
+ )
182
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
183
+ causal = torch.tril(
184
+ torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
185
+ ).view(1, 1, seq_len, seq_len)
186
+ prefix = prefix_mask.view(-1, 1, 1, seq_len)
187
+ cannot_attend = ~torch.logical_or(causal, prefix.bool())
188
+ min_val = torch.finfo(attn_bias.dtype).min
189
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
190
+ return attn_bias
191
+
192
+ def _apply_sequence_id(
193
+ self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
194
+ ):
195
+ seq_len = sequence_id.shape[-1]
196
+ if seq_len > self.config.max_seq_len:
197
+ raise ValueError(
198
+ f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
199
+ )
200
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
201
+ cannot_attend = torch.logical_not(
202
+ torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
203
+ ).unsqueeze(1)
204
+ min_val = torch.finfo(attn_bias.dtype).min
205
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
206
+ return attn_bias
207
+
208
+ def forward(
209
+ self,
210
+ input_ids: torch.LongTensor,
211
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
212
+ attention_mask: Optional[torch.ByteTensor] = None,
213
+ prefix_mask: Optional[torch.ByteTensor] = None,
214
+ sequence_id: Optional[torch.LongTensor] = None,
215
+ return_dict: Optional[bool] = None,
216
+ output_attentions: Optional[bool] = None,
217
+ output_hidden_states: Optional[bool] = None,
218
+ use_cache: Optional[bool] = None,
219
+ inputs_embeds: Optional[torch.Tensor] = None,
220
+ ):
221
+ return_dict = (
222
+ return_dict if return_dict is not None else self.config.return_dict
223
+ )
224
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
225
+ if attention_mask is not None:
226
+ attention_mask = attention_mask.bool()
227
+ if prefix_mask is not None:
228
+ prefix_mask = prefix_mask.bool()
229
+ if not return_dict:
230
+ raise NotImplementedError(
231
+ "return_dict False is not implemented yet for MPT"
232
+ )
233
+ if output_attentions:
234
+ if self.attn_impl != "torch":
235
+ raise NotImplementedError(
236
+ "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`."
237
+ )
238
+ if (
239
+ attention_mask is not None
240
+ and attention_mask[:, 0].sum() != attention_mask.shape[0]
241
+ and self.training
242
+ ):
243
+ raise NotImplementedError(
244
+ "MPT does not support training with left padding."
245
+ )
246
+ if self.prefix_lm and prefix_mask is None:
247
+ raise ValueError(
248
+ "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
249
+ )
250
+ if self.training:
251
+ if self.attn_uses_sequence_id and sequence_id is None:
252
+ raise ValueError(
253
+ "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
254
+ + "and the model is in train mode."
255
+ )
256
+ elif self.attn_uses_sequence_id is False and sequence_id is not None:
257
+ warnings.warn(
258
+ "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
259
+ + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
260
+ )
261
+ if input_ids is not None:
262
+ S = input_ids.size(1)
263
+ assert (
264
+ S <= self.config.max_seq_len
265
+ ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
266
+ tok_emb = self.wte(input_ids)
267
+ else:
268
+ assert inputs_embeds is not None
269
+ assert (
270
+ self.alibi
271
+ ), "inputs_embeds is not implemented for MPT unless for alibi."
272
+ S = inputs_embeds.size(1)
273
+ tok_emb = inputs_embeds
274
+ if self.alibi:
275
+ x = tok_emb
276
+ else:
277
+ past_position = 0
278
+ if past_key_values is not None:
279
+ if len(past_key_values) != self.config.n_layers:
280
+ raise ValueError(
281
+ f"past_key_values must provide a past_key_value for each attention "
282
+ + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
283
+ )
284
+ past_position = past_key_values[0][0].size(1)
285
+ if self.attn_impl == "torch":
286
+ past_position = past_key_values[0][0].size(3)
287
+ if S + past_position > self.config.max_seq_len:
288
+ raise ValueError(
289
+ f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
290
+ )
291
+ pos = torch.arange(
292
+ past_position,
293
+ S + past_position,
294
+ dtype=torch.long,
295
+ device=input_ids.device,
296
+ ).unsqueeze(0)
297
+ if attention_mask is not None:
298
+ pos = torch.clamp(
299
+ pos
300
+ - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
301
+ :, past_position:
302
+ ],
303
+ min=0,
304
+ )
305
+ pos_emb = self.wpe(pos)
306
+ x = tok_emb + pos_emb
307
+ if self.embedding_fraction == 1:
308
+ x = self.emb_drop(x)
309
+ else:
310
+ x_shrunk = x * self.embedding_fraction + x.detach() * (
311
+ 1 - self.embedding_fraction
312
+ )
313
+ assert isinstance(self.emb_drop, nn.Module)
314
+ x = self.emb_drop(x_shrunk)
315
+ (attn_bias, attention_mask) = self._attn_bias(
316
+ device=x.device,
317
+ dtype=torch.float32,
318
+ attention_mask=attention_mask,
319
+ prefix_mask=prefix_mask,
320
+ sequence_id=sequence_id,
321
+ )
322
+ if use_cache and past_key_values is None:
323
+ past_key_values = [() for _ in range(self.config.n_layers)]
324
+ all_hidden_states = () if output_hidden_states else None
325
+ all_self_attns = () if output_attentions else None
326
+ for b_idx, block in enumerate(self.blocks):
327
+ if output_hidden_states:
328
+ assert all_hidden_states is not None
329
+ all_hidden_states = all_hidden_states + (x,)
330
+ past_key_value = (
331
+ past_key_values[b_idx] if past_key_values is not None else None
332
+ )
333
+ if self.gradient_checkpointing and self.training:
334
+ (x, attn_weights, past_key_value) = torch.utils.checkpoint.checkpoint(
335
+ block, x, past_key_value, attn_bias, attention_mask, self.is_causal
336
+ )
337
+ else:
338
+ (x, attn_weights, past_key_value) = block(
339
+ x,
340
+ past_key_value=past_key_value,
341
+ attn_bias=attn_bias,
342
+ attention_mask=attention_mask,
343
+ is_causal=self.is_causal,
344
+ )
345
+ if past_key_values is not None:
346
+ past_key_values[b_idx] = past_key_value
347
+ if output_attentions:
348
+ assert all_self_attns is not None
349
+ all_self_attns = all_self_attns + (attn_weights,)
350
+ x = self.norm_f(x)
351
+ if output_hidden_states:
352
+ assert all_hidden_states is not None
353
+ all_hidden_states = all_hidden_states + (x,)
354
+ return BaseModelOutputWithPast(
355
+ last_hidden_state=x,
356
+ past_key_values=past_key_values,
357
+ hidden_states=all_hidden_states,
358
+ attentions=all_self_attns,
359
+ )
360
+
361
+ def param_init_fn(self, module):
362
+ init_fn_name = self.config.init_config["name"]
363
+ MODEL_INIT_REGISTRY[init_fn_name](
364
+ module=module,
365
+ n_layers=self.config.n_layers,
366
+ d_model=self.config.d_model,
367
+ **self.config.init_config,
368
+ )
369
+
370
+ def fsdp_wrap_fn(self, module):
371
+ return isinstance(module, MPTBlock)
372
+
373
+ def activation_checkpointing_fn(self, module):
374
+ return isinstance(module, MPTBlock)
375
+
376
+
377
+ class MPTForCausalLM(MPTPreTrainedModel):
378
+ def __init__(self, config: MPTConfig):
379
+ super().__init__(config)
380
+ if not config.tie_word_embeddings:
381
+ raise ValueError("MPTForCausalLM only supports tied word embeddings")
382
+ print(f"Instantiating an MPTForCausalLM model from {__file__}")
383
+ self.transformer = MPTModel(config)
384
+ for child in self.transformer.children():
385
+ if isinstance(child, torch.nn.ModuleList):
386
+ continue
387
+ if isinstance(child, torch.nn.Module):
388
+ child._fsdp_wrap = True
389
+ self.logit_scale = None
390
+ if config.logit_scale is not None:
391
+ logit_scale = config.logit_scale
392
+ if isinstance(logit_scale, str):
393
+ if logit_scale == "inv_sqrt_d_model":
394
+ logit_scale = 1 / math.sqrt(config.d_model)
395
+ else:
396
+ raise ValueError(
397
+ f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
398
+ )
399
+ self.logit_scale = logit_scale
400
+
401
+ def get_input_embeddings(self):
402
+ return self.transformer.wte
403
+
404
+ def set_input_embeddings(self, value):
405
+ self.transformer.wte = value
406
+
407
+ def get_output_embeddings(self):
408
+ return self.transformer.wte
409
+
410
+ def set_output_embeddings(self, new_embeddings):
411
+ self.transformer.wte = new_embeddings
412
+
413
+ def set_decoder(self, decoder):
414
+ self.transformer = decoder
415
+
416
+ def get_decoder(self):
417
+ return self.transformer
418
+
419
+ def forward(
420
+ self,
421
+ input_ids: torch.LongTensor,
422
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
423
+ attention_mask: Optional[torch.ByteTensor] = None,
424
+ prefix_mask: Optional[torch.ByteTensor] = None,
425
+ sequence_id: Optional[torch.LongTensor] = None,
426
+ labels: Optional[torch.LongTensor] = None,
427
+ return_dict: Optional[bool] = None,
428
+ output_attentions: Optional[bool] = None,
429
+ output_hidden_states: Optional[bool] = None,
430
+ use_cache: Optional[bool] = None,
431
+ inputs_embeds: Optional[torch.FloatTensor] = None,
432
+ ):
433
+ return_dict = (
434
+ return_dict if return_dict is not None else self.config.return_dict
435
+ )
436
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
437
+ if inputs_embeds is not None:
438
+ raise NotImplementedError(
439
+ "inputs_embeds has to be None (for hf/peft support)."
440
+ )
441
+ outputs = self.transformer(
442
+ input_ids=input_ids,
443
+ past_key_values=past_key_values,
444
+ attention_mask=attention_mask,
445
+ prefix_mask=prefix_mask,
446
+ sequence_id=sequence_id,
447
+ return_dict=return_dict,
448
+ output_attentions=output_attentions,
449
+ output_hidden_states=output_hidden_states,
450
+ use_cache=use_cache,
451
+ )
452
+ logits = self.transformer.wte(
453
+ outputs.last_hidden_state.to(self.transformer.wte.weight.device), True
454
+ )
455
+ if self.logit_scale is not None:
456
+ if self.logit_scale == 0:
457
+ warnings.warn(
458
+ f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
459
+ )
460
+ logits *= self.logit_scale
461
+ loss = None
462
+ if labels is not None:
463
+ labels = torch.roll(labels, shifts=-1)
464
+ labels[:, -1] = -100
465
+ loss = F.cross_entropy(
466
+ logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
467
+ )
468
+ return CausalLMOutputWithPast(
469
+ loss=loss,
470
+ logits=logits,
471
+ past_key_values=outputs.past_key_values,
472
+ hidden_states=outputs.hidden_states,
473
+ attentions=outputs.attentions,
474
+ )
475
+
476
+ def param_init_fn(self, module):
477
+ init_fn_name = self.config.init_config["name"]
478
+ MODEL_INIT_REGISTRY[init_fn_name](
479
+ module=module,
480
+ n_layers=self.config.n_layers,
481
+ d_model=self.config.d_model,
482
+ **self.config.init_config,
483
+ )
484
+
485
+ def fsdp_wrap_fn(self, module):
486
+ return isinstance(module, MPTBlock)
487
+
488
+ def activation_checkpointing_fn(self, module):
489
+ return isinstance(module, MPTBlock)
490
+
491
+ def prepare_inputs_for_generation(
492
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
493
+ ):
494
+ if inputs_embeds is not None:
495
+ raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
496
+ attention_mask = kwargs["attention_mask"].bool()
497
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
498
+ raise NotImplementedError(
499
+ "MPT does not support generation with right padding."
500
+ )
501
+ if self.transformer.attn_uses_sequence_id and self.training:
502
+ sequence_id = torch.zeros_like(input_ids[:1])
503
+ else:
504
+ sequence_id = None
505
+ if past_key_values is not None:
506
+ input_ids = input_ids[:, -1].unsqueeze(-1)
507
+ if self.transformer.prefix_lm:
508
+ prefix_mask = torch.ones_like(attention_mask)
509
+ if kwargs.get("use_cache") == False:
510
+ raise NotImplementedError(
511
+ "MPT with prefix_lm=True does not support use_cache=False."
512
+ )
513
+ else:
514
+ prefix_mask = None
515
+ return {
516
+ "input_ids": input_ids,
517
+ "attention_mask": attention_mask,
518
+ "prefix_mask": prefix_mask,
519
+ "sequence_id": sequence_id,
520
+ "past_key_values": past_key_values,
521
+ "use_cache": kwargs.get("use_cache", True),
522
+ }
523
+
524
+ @staticmethod
525
+ def _reorder_cache(past_key_values, beam_idx):
526
+ """Used by HuggingFace generate when using beam search with kv-caching.
527
+
528
+ See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
529
+ for an example in transformers.
530
+ """
531
+ reordered_past = []
532
+ for layer_past in past_key_values:
533
+ reordered_past += [
534
+ tuple(
535
+ (past_state.index_select(0, beam_idx) for past_state in layer_past)
536
+ )
537
+ ]
538
+ return reordered_past
lisa_on_cuda/llava/model/language_model/mpt/norm.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def _cast_if_autocast_enabled(tensor):
5
+ if torch.is_autocast_enabled():
6
+ if tensor.device.type == "cuda":
7
+ dtype = torch.get_autocast_gpu_dtype()
8
+ elif tensor.device.type == "cpu":
9
+ dtype = torch.get_autocast_cpu_dtype()
10
+ else:
11
+ raise NotImplementedError()
12
+ return tensor.to(dtype=dtype)
13
+ return tensor
14
+
15
+
16
+ class LPLayerNorm(torch.nn.LayerNorm):
17
+ def __init__(
18
+ self,
19
+ normalized_shape,
20
+ eps=1e-05,
21
+ elementwise_affine=True,
22
+ device=None,
23
+ dtype=None,
24
+ ):
25
+ super().__init__(
26
+ normalized_shape=normalized_shape,
27
+ eps=eps,
28
+ elementwise_affine=elementwise_affine,
29
+ device=device,
30
+ dtype=dtype,
31
+ )
32
+
33
+ def forward(self, x):
34
+ module_device = x.device
35
+ downcast_x = _cast_if_autocast_enabled(x)
36
+ downcast_weight = (
37
+ _cast_if_autocast_enabled(self.weight)
38
+ if self.weight is not None
39
+ else self.weight
40
+ )
41
+ downcast_bias = (
42
+ _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
43
+ )
44
+ with torch.autocast(enabled=False, device_type=module_device.type):
45
+ return torch.nn.functional.layer_norm(
46
+ downcast_x,
47
+ self.normalized_shape,
48
+ downcast_weight,
49
+ downcast_bias,
50
+ self.eps,
51
+ )
52
+
53
+
54
+ def rms_norm(x, weight=None, eps=1e-05):
55
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
56
+ if weight is not None:
57
+ return output * weight
58
+ return output
59
+
60
+
61
+ class RMSNorm(torch.nn.Module):
62
+ def __init__(
63
+ self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
64
+ ):
65
+ super().__init__()
66
+ self.eps = eps
67
+ if weight:
68
+ self.weight = torch.nn.Parameter(
69
+ torch.ones(normalized_shape, dtype=dtype, device=device)
70
+ )
71
+ else:
72
+ self.register_parameter("weight", None)
73
+
74
+ def forward(self, x):
75
+ return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
76
+
77
+
78
+ class LPRMSNorm(RMSNorm):
79
+ def __init__(
80
+ self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
81
+ ):
82
+ super().__init__(
83
+ normalized_shape=normalized_shape,
84
+ eps=eps,
85
+ weight=weight,
86
+ dtype=dtype,
87
+ device=device,
88
+ )
89
+
90
+ def forward(self, x):
91
+ downcast_x = _cast_if_autocast_enabled(x)
92
+ downcast_weight = (
93
+ _cast_if_autocast_enabled(self.weight)
94
+ if self.weight is not None
95
+ else self.weight
96
+ )
97
+ with torch.autocast(enabled=False, device_type=x.device.type):
98
+ return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
99
+
100
+
101
+ NORM_CLASS_REGISTRY = {
102
+ "layernorm": torch.nn.LayerNorm,
103
+ "low_precision_layernorm": LPLayerNorm,
104
+ "rmsnorm": RMSNorm,
105
+ "low_precision_rmsnorm": LPRMSNorm,
106
+ }
lisa_on_cuda/llava/model/language_model/mpt/param_init_fns.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from collections.abc import Sequence
4
+ from functools import partial
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from .norm import NORM_CLASS_REGISTRY
11
+
12
+
13
+ def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
14
+ del kwargs
15
+ if verbose > 1:
16
+ warnings.warn(f"Initializing network using module's reset_parameters attribute")
17
+ if hasattr(module, "reset_parameters"):
18
+ module.reset_parameters()
19
+
20
+
21
+ def fused_init_helper_(module: nn.Module, init_fn_):
22
+ _fused = getattr(module, "_fused", None)
23
+ if _fused is None:
24
+ raise RuntimeError(f"Internal logic error")
25
+ (dim, splits) = _fused
26
+ splits = (0, *splits, module.weight.size(dim))
27
+ for s, e in zip(splits[:-1], splits[1:]):
28
+ slice_indices = [slice(None)] * module.weight.ndim
29
+ slice_indices[dim] = slice(s, e)
30
+ init_fn_(module.weight[slice_indices])
31
+
32
+
33
+ def generic_param_init_fn_(
34
+ module: nn.Module,
35
+ init_fn_,
36
+ n_layers: int,
37
+ d_model: Optional[int] = None,
38
+ init_div_is_residual: Union[int, float, str, bool] = True,
39
+ emb_init_std: Optional[float] = None,
40
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
41
+ verbose: int = 0,
42
+ **kwargs,
43
+ ):
44
+ del kwargs
45
+ if verbose > 1:
46
+ warnings.warn(f"If model has bias parameters they are initialized to 0.")
47
+ init_div_is_residual = init_div_is_residual
48
+ if init_div_is_residual is False:
49
+ div_is_residual = 1.0
50
+ elif init_div_is_residual is True:
51
+ div_is_residual = math.sqrt(2 * n_layers)
52
+ elif isinstance(init_div_is_residual, float) or isinstance(
53
+ init_div_is_residual, int
54
+ ):
55
+ div_is_residual = init_div_is_residual
56
+ elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
57
+ div_is_residual = float(init_div_is_residual)
58
+ else:
59
+ div_is_residual = 1.0
60
+ raise ValueError(
61
+ f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}"
62
+ )
63
+ if init_div_is_residual is not False:
64
+ if verbose > 1:
65
+ warnings.warn(
66
+ f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
67
+ + f"Set `init_div_is_residual: false` in init config to disable this."
68
+ )
69
+ if isinstance(module, nn.Linear):
70
+ if hasattr(module, "_fused"):
71
+ fused_init_helper_(module, init_fn_)
72
+ else:
73
+ init_fn_(module.weight)
74
+ if module.bias is not None:
75
+ torch.nn.init.zeros_(module.bias)
76
+ if init_div_is_residual is not False and getattr(module, "_is_residual", False):
77
+ with torch.no_grad():
78
+ module.weight.div_(div_is_residual)
79
+ elif isinstance(module, nn.Embedding):
80
+ if emb_init_std is not None:
81
+ std = emb_init_std
82
+ if std == 0:
83
+ warnings.warn(f"Embedding layer initialized to 0.")
84
+ emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
85
+ if verbose > 1:
86
+ warnings.warn(
87
+ f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}."
88
+ )
89
+ elif emb_init_uniform_lim is not None:
90
+ lim = emb_init_uniform_lim
91
+ if isinstance(lim, Sequence):
92
+ if len(lim) > 2:
93
+ raise ValueError(
94
+ f"Uniform init requires a min and a max limit. User input: {lim}."
95
+ )
96
+ if lim[0] == lim[1]:
97
+ warnings.warn(f"Embedding layer initialized to {lim[0]}.")
98
+ else:
99
+ if lim == 0:
100
+ warnings.warn(f"Embedding layer initialized to 0.")
101
+ lim = [-lim, lim]
102
+ (a, b) = lim
103
+ emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
104
+ if verbose > 1:
105
+ warnings.warn(
106
+ f"Embedding layer initialized using uniform distribution in range {lim}."
107
+ )
108
+ else:
109
+ emb_init_fn_ = init_fn_
110
+ emb_init_fn_(module.weight)
111
+ elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
112
+ if verbose > 1:
113
+ warnings.warn(
114
+ f"Norm weights are set to 1. If norm layer has a bias it is initialized to 0."
115
+ )
116
+ if hasattr(module, "weight") and module.weight is not None:
117
+ torch.nn.init.ones_(module.weight)
118
+ if hasattr(module, "bias") and module.bias is not None:
119
+ torch.nn.init.zeros_(module.bias)
120
+ elif isinstance(module, nn.MultiheadAttention):
121
+ if module._qkv_same_embed_dim:
122
+ assert module.in_proj_weight is not None
123
+ assert (
124
+ module.q_proj_weight is None
125
+ and module.k_proj_weight is None
126
+ and (module.v_proj_weight is None)
127
+ )
128
+ assert d_model is not None
129
+ _d = d_model
130
+ splits = (0, _d, 2 * _d, 3 * _d)
131
+ for s, e in zip(splits[:-1], splits[1:]):
132
+ init_fn_(module.in_proj_weight[s:e])
133
+ else:
134
+ assert (
135
+ module.q_proj_weight is not None
136
+ and module.k_proj_weight is not None
137
+ and (module.v_proj_weight is not None)
138
+ )
139
+ assert module.in_proj_weight is None
140
+ init_fn_(module.q_proj_weight)
141
+ init_fn_(module.k_proj_weight)
142
+ init_fn_(module.v_proj_weight)
143
+ if module.in_proj_bias is not None:
144
+ torch.nn.init.zeros_(module.in_proj_bias)
145
+ if module.bias_k is not None:
146
+ torch.nn.init.zeros_(module.bias_k)
147
+ if module.bias_v is not None:
148
+ torch.nn.init.zeros_(module.bias_v)
149
+ init_fn_(module.out_proj.weight)
150
+ if init_div_is_residual is not False and getattr(
151
+ module.out_proj, "_is_residual", False
152
+ ):
153
+ with torch.no_grad():
154
+ module.out_proj.weight.div_(div_is_residual)
155
+ if module.out_proj.bias is not None:
156
+ torch.nn.init.zeros_(module.out_proj.bias)
157
+ else:
158
+ for _ in module.parameters(recurse=False):
159
+ raise NotImplementedError(
160
+ f"{module.__class__.__name__} parameters are not initialized by param_init_fn."
161
+ )
162
+
163
+
164
+ def _normal_init_(std, mean=0.0):
165
+ return partial(torch.nn.init.normal_, mean=mean, std=std)
166
+
167
+
168
+ def _normal_param_init_fn_(
169
+ module: nn.Module,
170
+ std: float,
171
+ n_layers: int,
172
+ d_model: Optional[int] = None,
173
+ init_div_is_residual: Union[int, float, str, bool] = True,
174
+ emb_init_std: Optional[float] = None,
175
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
176
+ verbose: int = 0,
177
+ **kwargs,
178
+ ):
179
+ del kwargs
180
+ init_fn_ = _normal_init_(std=std)
181
+ if verbose > 1:
182
+ warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}")
183
+ generic_param_init_fn_(
184
+ module=module,
185
+ init_fn_=init_fn_,
186
+ d_model=d_model,
187
+ n_layers=n_layers,
188
+ init_div_is_residual=init_div_is_residual,
189
+ emb_init_std=emb_init_std,
190
+ emb_init_uniform_lim=emb_init_uniform_lim,
191
+ verbose=verbose,
192
+ )
193
+
194
+
195
+ def baseline_param_init_fn_(
196
+ module: nn.Module,
197
+ init_std: float,
198
+ n_layers: int,
199
+ d_model: Optional[int] = None,
200
+ init_div_is_residual: Union[int, float, str, bool] = True,
201
+ emb_init_std: Optional[float] = None,
202
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
203
+ verbose: int = 0,
204
+ **kwargs,
205
+ ):
206
+ del kwargs
207
+ if init_std is None:
208
+ raise ValueError(
209
+ "You must set model.init_config['init_std'] to a float value to use the default initialization scheme."
210
+ )
211
+ _normal_param_init_fn_(
212
+ module=module,
213
+ std=init_std,
214
+ d_model=d_model,
215
+ n_layers=n_layers,
216
+ init_div_is_residual=init_div_is_residual,
217
+ emb_init_std=emb_init_std,
218
+ emb_init_uniform_lim=emb_init_uniform_lim,
219
+ verbose=verbose,
220
+ )
221
+
222
+
223
+ def small_param_init_fn_(
224
+ module: nn.Module,
225
+ n_layers: int,
226
+ d_model: int,
227
+ init_div_is_residual: Union[int, float, str, bool] = True,
228
+ emb_init_std: Optional[float] = None,
229
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
230
+ verbose: int = 0,
231
+ **kwargs,
232
+ ):
233
+ del kwargs
234
+ std = math.sqrt(2 / (5 * d_model))
235
+ _normal_param_init_fn_(
236
+ module=module,
237
+ std=std,
238
+ d_model=d_model,
239
+ n_layers=n_layers,
240
+ init_div_is_residual=init_div_is_residual,
241
+ emb_init_std=emb_init_std,
242
+ emb_init_uniform_lim=emb_init_uniform_lim,
243
+ verbose=verbose,
244
+ )
245
+
246
+
247
+ def neox_param_init_fn_(
248
+ module: nn.Module,
249
+ n_layers: int,
250
+ d_model: int,
251
+ emb_init_std: Optional[float] = None,
252
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
253
+ verbose: int = 0,
254
+ **kwargs,
255
+ ):
256
+ """From section 2.3.1 of GPT-NeoX-20B:
257
+
258
+ An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
259
+ see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
260
+ and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
261
+ """
262
+ del kwargs
263
+ residual_div = n_layers / math.sqrt(10)
264
+ if verbose > 1:
265
+ warnings.warn(f"setting init_div_is_residual to {residual_div}")
266
+ small_param_init_fn_(
267
+ module=module,
268
+ d_model=d_model,
269
+ n_layers=n_layers,
270
+ init_div_is_residual=residual_div,
271
+ emb_init_std=emb_init_std,
272
+ emb_init_uniform_lim=emb_init_uniform_lim,
273
+ verbose=verbose,
274
+ )
275
+
276
+
277
+ def kaiming_uniform_param_init_fn_(
278
+ module: nn.Module,
279
+ n_layers: int,
280
+ d_model: Optional[int] = None,
281
+ init_div_is_residual: Union[int, float, str, bool] = True,
282
+ emb_init_std: Optional[float] = None,
283
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
284
+ init_gain: float = 0,
285
+ fan_mode: str = "fan_in",
286
+ init_nonlinearity: str = "leaky_relu",
287
+ verbose: int = 0,
288
+ **kwargs,
289
+ ):
290
+ del kwargs
291
+ if verbose > 1:
292
+ warnings.warn(
293
+ f"Using nn.init.kaiming_uniform_ init fn with parameters: "
294
+ + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
295
+ )
296
+ kaiming_uniform_ = partial(
297
+ nn.init.kaiming_uniform_,
298
+ a=init_gain,
299
+ mode=fan_mode,
300
+ nonlinearity=init_nonlinearity,
301
+ )
302
+ generic_param_init_fn_(
303
+ module=module,
304
+ init_fn_=kaiming_uniform_,
305
+ d_model=d_model,
306
+ n_layers=n_layers,
307
+ init_div_is_residual=init_div_is_residual,
308
+ emb_init_std=emb_init_std,
309
+ emb_init_uniform_lim=emb_init_uniform_lim,
310
+ verbose=verbose,
311
+ )
312
+
313
+
314
+ def kaiming_normal_param_init_fn_(
315
+ module: nn.Module,
316
+ n_layers: int,
317
+ d_model: Optional[int] = None,
318
+ init_div_is_residual: Union[int, float, str, bool] = True,
319
+ emb_init_std: Optional[float] = None,
320
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
321
+ init_gain: float = 0,
322
+ fan_mode: str = "fan_in",
323
+ init_nonlinearity: str = "leaky_relu",
324
+ verbose: int = 0,
325
+ **kwargs,
326
+ ):
327
+ del kwargs
328
+ if verbose > 1:
329
+ warnings.warn(
330
+ f"Using nn.init.kaiming_normal_ init fn with parameters: "
331
+ + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
332
+ )
333
+ kaiming_normal_ = partial(
334
+ torch.nn.init.kaiming_normal_,
335
+ a=init_gain,
336
+ mode=fan_mode,
337
+ nonlinearity=init_nonlinearity,
338
+ )
339
+ generic_param_init_fn_(
340
+ module=module,
341
+ init_fn_=kaiming_normal_,
342
+ d_model=d_model,
343
+ n_layers=n_layers,
344
+ init_div_is_residual=init_div_is_residual,
345
+ emb_init_std=emb_init_std,
346
+ emb_init_uniform_lim=emb_init_uniform_lim,
347
+ verbose=verbose,
348
+ )
349
+
350
+
351
+ def xavier_uniform_param_init_fn_(
352
+ module: nn.Module,
353
+ n_layers: int,
354
+ d_model: Optional[int] = None,
355
+ init_div_is_residual: Union[int, float, str, bool] = True,
356
+ emb_init_std: Optional[float] = None,
357
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
358
+ init_gain: float = 0,
359
+ verbose: int = 0,
360
+ **kwargs,
361
+ ):
362
+ del kwargs
363
+ xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
364
+ if verbose > 1:
365
+ warnings.warn(
366
+ f"Using torch.nn.init.xavier_uniform_ init fn with parameters: "
367
+ + f"gain={init_gain}"
368
+ )
369
+ generic_param_init_fn_(
370
+ module=module,
371
+ init_fn_=xavier_uniform_,
372
+ d_model=d_model,
373
+ n_layers=n_layers,
374
+ init_div_is_residual=init_div_is_residual,
375
+ emb_init_std=emb_init_std,
376
+ emb_init_uniform_lim=emb_init_uniform_lim,
377
+ verbose=verbose,
378
+ )
379
+
380
+
381
+ def xavier_normal_param_init_fn_(
382
+ module: nn.Module,
383
+ n_layers: int,
384
+ d_model: Optional[int] = None,
385
+ init_div_is_residual: Union[int, float, str, bool] = True,
386
+ emb_init_std: Optional[float] = None,
387
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
388
+ init_gain: float = 0,
389
+ verbose: int = 0,
390
+ **kwargs,
391
+ ):
392
+ xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
393
+ if verbose > 1:
394
+ warnings.warn(
395
+ f"Using torch.nn.init.xavier_normal_ init fn with parameters: "
396
+ + f"gain={init_gain}"
397
+ )
398
+ generic_param_init_fn_(
399
+ module=module,
400
+ init_fn_=xavier_normal_,
401
+ d_model=d_model,
402
+ n_layers=n_layers,
403
+ init_div_is_residual=init_div_is_residual,
404
+ emb_init_std=emb_init_std,
405
+ emb_init_uniform_lim=emb_init_uniform_lim,
406
+ verbose=verbose,
407
+ )
408
+
409
+
410
+ MODEL_INIT_REGISTRY = {
411
+ "default_": torch_default_param_init_fn_,
412
+ "baseline_": baseline_param_init_fn_,
413
+ "kaiming_uniform_": kaiming_uniform_param_init_fn_,
414
+ "kaiming_normal_": kaiming_normal_param_init_fn_,
415
+ "neox_init_": neox_param_init_fn_,
416
+ "small_init_": small_param_init_fn_,
417
+ "xavier_uniform_": xavier_uniform_param_init_fn_,
418
+ "xavier_normal_": xavier_normal_param_init_fn_,
419
+ }
lisa_on_cuda/llava/model/llava_arch.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from lisa_on_cuda.utils.utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX
22
+
23
+ from .multimodal_encoder.builder import build_vision_tower
24
+
25
+
26
+ class LlavaMetaModel:
27
+ def __init__(self, config):
28
+ super(LlavaMetaModel, self).__init__(config)
29
+
30
+ if hasattr(config, "mm_vision_tower"):
31
+ self.vision_tower = build_vision_tower(config, delay_load=True)
32
+ self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
33
+
34
+ def get_vision_tower(self):
35
+ vision_tower = getattr(self, "vision_tower", None)
36
+ if type(vision_tower) is list:
37
+ vision_tower = vision_tower[0]
38
+ return vision_tower
39
+
40
+ def initialize_vision_modules(self, model_args, fsdp=None):
41
+ vision_tower = model_args.vision_tower
42
+ mm_vision_select_layer = model_args.mm_vision_select_layer
43
+ mm_vision_select_feature = model_args.mm_vision_select_feature
44
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
45
+
46
+ self.config.mm_vision_tower = vision_tower
47
+
48
+ vision_tower = build_vision_tower(model_args)
49
+
50
+ if fsdp is not None and len(fsdp) > 0:
51
+ self.vision_tower = [vision_tower]
52
+ else:
53
+ self.vision_tower = vision_tower
54
+
55
+ self.config.use_mm_proj = True
56
+ self.config.mm_hidden_size = vision_tower.hidden_size
57
+ self.config.mm_vision_select_layer = mm_vision_select_layer
58
+ self.config.mm_vision_select_feature = mm_vision_select_feature
59
+
60
+ if not hasattr(self, "mm_projector"):
61
+ self.mm_projector = nn.Linear(
62
+ self.config.mm_hidden_size, self.config.hidden_size
63
+ )
64
+
65
+ if pretrain_mm_mlp_adapter is not None:
66
+ mm_projector_weights = torch.load(
67
+ pretrain_mm_mlp_adapter, map_location="cpu"
68
+ )
69
+
70
+ def get_w(weights, keyword):
71
+ return {
72
+ k.split(keyword + ".")[1]: v
73
+ for k, v in weights.items()
74
+ if keyword in k
75
+ }
76
+
77
+ self.mm_projector.load_state_dict(
78
+ get_w(mm_projector_weights, "mm_projector")
79
+ )
80
+
81
+
82
+ class LlavaMetaForCausalLM(ABC):
83
+ @abstractmethod
84
+ def get_model(self):
85
+ pass
86
+
87
+ def get_vision_tower(self):
88
+ return self.get_model().get_vision_tower()
89
+
90
+ def encode_images(self, images):
91
+ image_features = self.get_model().get_vision_tower()(images)
92
+ image_features = self.get_model().mm_projector(image_features)
93
+ return image_features
94
+
95
+ def prepare_inputs_labels_for_multimodal(
96
+ self, input_ids, attention_mask, past_key_values, labels, images
97
+ ):
98
+ vision_tower = self.get_vision_tower()
99
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
100
+ if (
101
+ past_key_values is not None
102
+ and vision_tower is not None
103
+ and images is not None
104
+ and input_ids.shape[1] == 1
105
+ ):
106
+ attention_mask = torch.ones(
107
+ (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
108
+ dtype=attention_mask.dtype,
109
+ device=attention_mask.device,
110
+ )
111
+ return input_ids, attention_mask, past_key_values, None, labels
112
+
113
+ if type(images) is list or images.ndim == 5:
114
+ concat_images = torch.cat([image for image in images], dim=0)
115
+ image_features = self.encode_images(concat_images)
116
+ split_sizes = [image.shape[0] for image in images]
117
+ image_features = torch.split(image_features, split_sizes, dim=0)
118
+ image_features = [x.flatten(0, 1) for x in image_features]
119
+ else:
120
+ image_features = self.encode_images(images)
121
+
122
+ new_input_embeds = []
123
+ new_labels = [] if labels is not None else None
124
+ cur_image_idx = 0
125
+ for batch_idx, cur_input_ids in enumerate(input_ids):
126
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
127
+ # multimodal LLM, but the current sample is not multimodal
128
+ cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
129
+ cur_input_embeds = (
130
+ cur_input_embeds
131
+ + (
132
+ 0.0 * self.get_model().mm_projector(vision_tower.dummy_feature)
133
+ ).sum()
134
+ )
135
+ new_input_embeds.append(cur_input_embeds)
136
+ if labels is not None:
137
+ new_labels.append(labels[batch_idx])
138
+ cur_image_idx += 1
139
+ continue
140
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
141
+ cur_new_input_embeds = []
142
+ if labels is not None:
143
+ cur_labels = labels[batch_idx]
144
+ cur_new_labels = []
145
+ assert cur_labels.shape == cur_input_ids.shape
146
+ while image_token_indices.numel() > 0:
147
+ cur_image_features = image_features[cur_image_idx]
148
+ image_token_start = image_token_indices[0]
149
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
150
+ self.config, "mm_use_im_start_end", False
151
+ ):
152
+ cur_new_input_embeds.append(
153
+ self.get_model()
154
+ .embed_tokens(cur_input_ids[: image_token_start - 1])
155
+ .detach()
156
+ )
157
+ cur_new_input_embeds.append(
158
+ self.get_model().embed_tokens(
159
+ cur_input_ids[image_token_start - 1 : image_token_start]
160
+ )
161
+ )
162
+ cur_new_input_embeds.append(cur_image_features)
163
+ cur_new_input_embeds.append(
164
+ self.get_model().embed_tokens(
165
+ cur_input_ids[image_token_start + 1 : image_token_start + 2]
166
+ )
167
+ )
168
+ if labels is not None:
169
+ cur_new_labels.append(cur_labels[:image_token_start])
170
+ cur_new_labels.append(
171
+ torch.full(
172
+ (cur_image_features.shape[0],),
173
+ IGNORE_INDEX,
174
+ device=labels.device,
175
+ dtype=labels.dtype,
176
+ )
177
+ )
178
+ cur_new_labels.append(
179
+ cur_labels[image_token_start : image_token_start + 1]
180
+ )
181
+ cur_labels = cur_labels[image_token_start + 2 :]
182
+ elif getattr(self.config, "mm_use_im_start_end", False):
183
+ cur_new_input_embeds.append(
184
+ self.get_model().embed_tokens(cur_input_ids[:image_token_start])
185
+ )
186
+ cur_new_input_embeds.append(cur_image_features)
187
+ cur_new_input_embeds.append(
188
+ self.get_model().embed_tokens(
189
+ cur_input_ids[image_token_start + 1 : image_token_start + 2]
190
+ )
191
+ )
192
+ if labels is not None:
193
+ cur_new_labels.append(cur_labels[:image_token_start])
194
+ cur_new_labels.append(
195
+ torch.full(
196
+ (cur_image_features.shape[0],),
197
+ IGNORE_INDEX,
198
+ device=labels.device,
199
+ dtype=labels.dtype,
200
+ )
201
+ )
202
+ cur_new_labels.append(
203
+ cur_labels[image_token_start + 1 : image_token_start + 2]
204
+ )
205
+ cur_labels = cur_labels[image_token_start + 2 :]
206
+ else:
207
+ cur_new_input_embeds.append(
208
+ self.get_model().embed_tokens(cur_input_ids[:image_token_start])
209
+ )
210
+ cur_new_input_embeds.append(cur_image_features)
211
+ if labels is not None:
212
+ cur_new_labels.append(cur_labels[:image_token_start])
213
+ cur_new_labels.append(
214
+ torch.full(
215
+ (cur_image_features.shape[0],),
216
+ IGNORE_INDEX,
217
+ device=labels.device,
218
+ dtype=labels.dtype,
219
+ )
220
+ )
221
+ cur_labels = cur_labels[image_token_start + 1 :]
222
+ cur_image_idx += 1
223
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
224
+ self.config, "mm_use_im_start_end", False
225
+ ):
226
+ cur_input_ids = cur_input_ids[image_token_start + 2 :]
227
+ elif getattr(self.config, "mm_use_im_start_end", False):
228
+ cur_input_ids = cur_input_ids[image_token_start + 2 :]
229
+ else:
230
+ cur_input_ids = cur_input_ids[image_token_start + 1 :]
231
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
232
+ if cur_input_ids.numel() > 0:
233
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
234
+ self.config, "mm_use_im_start_end", False
235
+ ):
236
+ cur_new_input_embeds.append(
237
+ self.get_model().embed_tokens(cur_input_ids).detach()
238
+ )
239
+ elif getattr(self.config, "mm_use_im_start_end", False):
240
+ cur_new_input_embeds.append(
241
+ self.get_model().embed_tokens(cur_input_ids)
242
+ )
243
+ else:
244
+ cur_new_input_embeds.append(
245
+ self.get_model().embed_tokens(cur_input_ids)
246
+ )
247
+ if labels is not None:
248
+ cur_new_labels.append(cur_labels)
249
+ cur_new_input_embeds = [
250
+ x.to(device=self.device) for x in cur_new_input_embeds
251
+ ]
252
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
253
+ new_input_embeds.append(cur_new_input_embeds)
254
+ if labels is not None:
255
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
256
+ new_labels.append(cur_new_labels)
257
+
258
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
259
+ max_len = max(x.shape[0] for x in new_input_embeds)
260
+
261
+ new_input_embeds_align = []
262
+ for cur_new_embed in new_input_embeds:
263
+ cur_new_embed = torch.cat(
264
+ (
265
+ cur_new_embed,
266
+ torch.zeros(
267
+ (max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
268
+ dtype=cur_new_embed.dtype,
269
+ device=cur_new_embed.device,
270
+ ),
271
+ ),
272
+ dim=0,
273
+ )
274
+ new_input_embeds_align.append(cur_new_embed)
275
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
276
+
277
+ if labels is not None:
278
+ new_labels_align = []
279
+ _new_labels = new_labels
280
+ for cur_new_label in new_labels:
281
+ cur_new_label = torch.cat(
282
+ (
283
+ cur_new_label,
284
+ torch.full(
285
+ (max_len - cur_new_label.shape[0],),
286
+ IGNORE_INDEX,
287
+ dtype=cur_new_label.dtype,
288
+ device=cur_new_label.device,
289
+ ),
290
+ ),
291
+ dim=0,
292
+ )
293
+ new_labels_align.append(cur_new_label)
294
+ new_labels = torch.stack(new_labels_align, dim=0)
295
+
296
+ if attention_mask is not None:
297
+ new_attention_mask = []
298
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(
299
+ attention_mask, _new_labels, new_labels
300
+ ):
301
+ new_attn_mask_pad_left = torch.full(
302
+ (cur_new_labels.shape[0] - labels.shape[1],),
303
+ True,
304
+ dtype=attention_mask.dtype,
305
+ device=attention_mask.device,
306
+ )
307
+ new_attn_mask_pad_right = torch.full(
308
+ (cur_new_labels_align.shape[0] - cur_new_labels.shape[0],),
309
+ False,
310
+ dtype=attention_mask.dtype,
311
+ device=attention_mask.device,
312
+ )
313
+ cur_new_attention_mask = torch.cat(
314
+ (
315
+ new_attn_mask_pad_left,
316
+ cur_attention_mask,
317
+ new_attn_mask_pad_right,
318
+ ),
319
+ dim=0,
320
+ )
321
+ new_attention_mask.append(cur_new_attention_mask)
322
+ attention_mask = torch.stack(new_attention_mask, dim=0)
323
+ assert attention_mask.shape == new_labels.shape
324
+ else:
325
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
326
+ if labels is not None:
327
+ new_labels = torch.stack(new_labels, dim=0)
328
+
329
+ if attention_mask is not None:
330
+ new_attn_mask_pad_left = torch.full(
331
+ (
332
+ attention_mask.shape[0],
333
+ new_input_embeds.shape[1] - input_ids.shape[1],
334
+ ),
335
+ True,
336
+ dtype=attention_mask.dtype,
337
+ device=attention_mask.device,
338
+ )
339
+ attention_mask = torch.cat(
340
+ (new_attn_mask_pad_left, attention_mask), dim=1
341
+ )
342
+ assert attention_mask.shape == new_input_embeds.shape[:2]
343
+
344
+ return None, attention_mask, past_key_values, new_input_embeds, new_labels
345
+
346
+ # def initialize_vision_tokenizer(self, model_args, tokenizer):
347
+ def initialize_vision_tokenizer(self, model_args, num_new_tokens):
348
+ # if model_args.mm_use_im_patch_token:
349
+ # tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
350
+ # self.resize_token_embeddings(len(tokenizer))
351
+
352
+ if model_args.mm_use_im_start_end:
353
+ # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
354
+ # self.resize_token_embeddings(len(tokenizer))
355
+
356
+ # if num_new_tokens > 0:
357
+ # input_embeddings = self.get_input_embeddings().weight.data
358
+ # output_embeddings = self.get_output_embeddings().weight.data
359
+
360
+ # input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
361
+ # dim=0, keepdim=True)
362
+ # output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
363
+ # dim=0, keepdim=True)
364
+
365
+ # input_embeddings[-num_new_tokens:] = input_embeddings_avg
366
+ # output_embeddings[-num_new_tokens:] = output_embeddings_avg
367
+
368
+ if model_args.tune_mm_mlp_adapter:
369
+ for p in self.get_input_embeddings().parameters():
370
+ p.requires_grad = True
371
+ for p in self.get_output_embeddings().parameters():
372
+ p.requires_grad = False
373
+
374
+ if model_args.pretrain_mm_mlp_adapter:
375
+ mm_projector_weights = torch.load(
376
+ model_args.pretrain_mm_mlp_adapter, map_location="cpu"
377
+ )
378
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
379
+ assert num_new_tokens == 2
380
+ if input_embeddings.shape == embed_tokens_weight.shape:
381
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[
382
+ -num_new_tokens:
383
+ ]
384
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
385
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
386
+ else:
387
+ raise ValueError(
388
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
389
+ )
390
+ elif model_args.mm_use_im_patch_token:
391
+ if model_args.tune_mm_mlp_adapter:
392
+ for p in self.get_input_embeddings().parameters():
393
+ p.requires_grad = False
394
+ for p in self.get_output_embeddings().parameters():
395
+ p.requires_grad = False
lisa_on_cuda/llava/model/make_delta.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from .utils import auto_upgrade
9
+ from tqdm import tqdm
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+
12
+
13
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
17
+ )
18
+
19
+ print("Loading target model")
20
+ auto_upgrade(target_model_path)
21
+ target = AutoModelForCausalLM.from_pretrained(
22
+ target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
23
+ )
24
+
25
+ print("Calculating delta")
26
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
27
+ if name not in base.state_dict():
28
+ assert name in [
29
+ "model.mm_projector.weight",
30
+ "model.mm_projector.bias",
31
+ ], f"{name} not in base model"
32
+ continue
33
+ if param.data.shape == base.state_dict()[name].shape:
34
+ param.data -= base.state_dict()[name]
35
+ else:
36
+ assert name in [
37
+ "model.embed_tokens.weight",
38
+ "lm_head.weight",
39
+ ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
40
+ bparam = base.state_dict()[name]
41
+ param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
42
+
43
+ print("Saving delta")
44
+ if hub_repo_id:
45
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
46
+ else:
47
+ kwargs = {}
48
+ target.save_pretrained(delta_path, **kwargs)
49
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
50
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
51
+
52
+
53
+ if __name__ == "__main__":
54
+ parser = argparse.ArgumentParser()
55
+ parser.add_argument("--base-model-path", type=str, required=True)
56
+ parser.add_argument("--target-model-path", type=str, required=True)
57
+ parser.add_argument("--delta-path", type=str, required=True)
58
+ parser.add_argument("--hub-repo-id", type=str, default=None)
59
+ args = parser.parse_args()
60
+
61
+ make_delta(
62
+ args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id
63
+ )
lisa_on_cuda/llava/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .clip_encoder import CLIPVisionTower
2
+
3
+
4
+ def build_vision_tower(vision_tower_cfg, **kwargs):
5
+ vision_tower = getattr(
6
+ vision_tower_cfg,
7
+ "mm_vision_tower",
8
+ getattr(vision_tower_cfg, "vision_tower", None),
9
+ )
10
+ if (
11
+ vision_tower.startswith("openai")
12
+ or vision_tower.startswith("laion")
13
+ or "clip" in vision_tower
14
+ ):
15
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
16
+
17
+ raise ValueError(f"Unknown vision tower: {vision_tower}")
lisa_on_cuda/llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
4
+
5
+
6
+ class CLIPVisionTower(nn.Module):
7
+ def __init__(self, vision_tower, args, delay_load=False):
8
+ super().__init__()
9
+
10
+ self.is_loaded = False
11
+
12
+ self.vision_tower_name = vision_tower
13
+ self.select_layer = args.mm_vision_select_layer
14
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
15
+
16
+ if not delay_load:
17
+ self.load_model()
18
+ else:
19
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
20
+
21
+ def load_model(self):
22
+ self.image_processor = CLIPImageProcessor.from_pretrained(
23
+ self.vision_tower_name
24
+ )
25
+ self.vision_tower = CLIPVisionModel.from_pretrained(
26
+ self.vision_tower_name, low_cpu_mem_usage=True
27
+ )
28
+ self.vision_tower.requires_grad_(False)
29
+ self.is_loaded = True
30
+
31
+ def feature_select(self, image_forward_outs):
32
+ image_features = image_forward_outs.hidden_states[self.select_layer]
33
+ if self.select_feature == "patch":
34
+ image_features = image_features[:, 1:]
35
+ elif self.select_feature == "cls_patch":
36
+ image_features = image_features
37
+ else:
38
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
39
+ return image_features
40
+
41
+ @torch.no_grad()
42
+ def forward(self, images):
43
+ if type(images) is list:
44
+ image_features = []
45
+ for image in images:
46
+ image_forward_out = self.vision_tower(
47
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
48
+ output_hidden_states=True,
49
+ )
50
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
51
+ image_features.append(image_feature)
52
+ else:
53
+ image_forward_outs = self.vision_tower(
54
+ images.to(device=self.device, dtype=self.dtype),
55
+ output_hidden_states=True,
56
+ )
57
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
58
+
59
+ torch.cuda.empty_cache()
60
+ return image_features
61
+
62
+ @property
63
+ def dummy_feature(self):
64
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
65
+
66
+ @property
67
+ def dtype(self):
68
+ return self.vision_tower.dtype
69
+
70
+ @property
71
+ def device(self):
72
+ return self.vision_tower.device
73
+
74
+ @property
75
+ def config(self):
76
+ if self.is_loaded:
77
+ return self.vision_tower.config
78
+ else:
79
+ return self.cfg_only
80
+
81
+ @property
82
+ def hidden_size(self):
83
+ return self.config.hidden_size
84
+
85
+ @property
86
+ def num_patches(self):
87
+ return (self.config.image_size // self.config.patch_size) ** 2
lisa_on_cuda/llava/model/utils.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ from transformers import AutoConfig
4
+
5
+
6
+ def auto_upgrade(config):
7
+ cfg = AutoConfig.from_pretrained(config)
8
+ if "llava" in config and "llava" not in cfg.model_type:
9
+ assert cfg.model_type == "llama"
10
+ print(
11
+ "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
12
+ )
13
+ print(
14
+ "You must upgrade the checkpoint to the new code base (this can be done automatically)."
15
+ )
16
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
17
+ if confirm.lower() in ["y", "yes"]:
18
+ print("Upgrading checkpoint...")
19
+ assert len(cfg.architectures) == 1
20
+ setattr(cfg.__class__, "model_type", "llava")
21
+ cfg.architectures[0] = "LlavaLlamaForCausalLM"
22
+ cfg.save_pretrained(config)
23
+ print("Checkpoint upgraded.")
24
+ else:
25
+ print("Checkpoint upgrade aborted.")
26
+ sys.exit(1)
lisa_on_cuda/llava/train/llama_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Optional, Tuple
3
+
4
+ import torch
5
+ import transformers
6
+ from einops import rearrange
7
+ from torch import nn
8
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
9
+
10
+ try:
11
+ from flash_attn.flash_attn_interface import \
12
+ flash_attn_unpadded_qkvpacked_func
13
+ except ImportError:
14
+ from flash_attn.flash_attn_interface import (
15
+ flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func,
16
+ )
17
+
18
+ from flash_attn.bert_padding import pad_input, unpad_input
19
+
20
+
21
+ def forward(
22
+ self,
23
+ hidden_states: torch.Tensor,
24
+ attention_mask: Optional[torch.Tensor] = None,
25
+ position_ids: Optional[torch.Tensor] = None,
26
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
27
+ output_attentions: bool = False,
28
+ use_cache: bool = False,
29
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
30
+ """Input shape: Batch x Time x Channel
31
+
32
+ attention_mask: [bsz, q_len]
33
+ """
34
+ bsz, q_len, _ = hidden_states.size()
35
+
36
+ query_states = (
37
+ self.q_proj(hidden_states)
38
+ .view(bsz, q_len, self.num_heads, self.head_dim)
39
+ .transpose(1, 2)
40
+ )
41
+ key_states = (
42
+ self.k_proj(hidden_states)
43
+ .view(bsz, q_len, self.num_heads, self.head_dim)
44
+ .transpose(1, 2)
45
+ )
46
+ value_states = (
47
+ self.v_proj(hidden_states)
48
+ .view(bsz, q_len, self.num_heads, self.head_dim)
49
+ .transpose(1, 2)
50
+ )
51
+ # [bsz, q_len, nh, hd]
52
+ # [bsz, nh, q_len, hd]
53
+
54
+ kv_seq_len = key_states.shape[-2]
55
+ assert past_key_value is None, "past_key_value is not supported"
56
+
57
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
58
+ query_states, key_states = apply_rotary_pos_emb(
59
+ query_states, key_states, cos, sin, position_ids
60
+ )
61
+ # [bsz, nh, t, hd]
62
+ assert not output_attentions, "output_attentions is not supported"
63
+ assert not use_cache, "use_cache is not supported"
64
+
65
+ # Flash attention codes from
66
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
67
+
68
+ # transform the data into the format required by flash attention
69
+ qkv = torch.stack(
70
+ [query_states, key_states, value_states], dim=2
71
+ ) # [bsz, nh, 3, q_len, hd]
72
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
73
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
74
+ # the attention_mask should be the same as the key_padding_mask
75
+ key_padding_mask = attention_mask
76
+
77
+ if key_padding_mask is None:
78
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
79
+ max_s = q_len
80
+ cu_q_lens = torch.arange(
81
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
82
+ )
83
+ output = flash_attn_unpadded_qkvpacked_func(
84
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
85
+ )
86
+ output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
87
+ else:
88
+ nheads = qkv.shape[-2]
89
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
90
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
91
+ x_unpad = rearrange(
92
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
93
+ )
94
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
95
+ x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
96
+ )
97
+ output = rearrange(
98
+ pad_input(
99
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
100
+ ),
101
+ "b s (h d) -> b s h d",
102
+ h=nheads,
103
+ )
104
+ return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
105
+
106
+
107
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
108
+ # requires the attention mask to be the same as the key_padding_mask
109
+ def _prepare_decoder_attention_mask(
110
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
111
+ ):
112
+ # [bsz, seq_len]
113
+ return attention_mask
114
+
115
+
116
+ def replace_llama_attn_with_flash_attn():
117
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
118
+ if cuda_major < 8:
119
+ logging.warning(
120
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
121
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
122
+ )
123
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
124
+ _prepare_decoder_attention_mask
125
+ )
126
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
lisa_on_cuda/llava/train/llava_trainer.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from transformers import Trainer
6
+
7
+
8
+ def maybe_zero_3(param, ignore_status=False, name=None):
9
+ from deepspeed import zero
10
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
11
+
12
+ if hasattr(param, "ds_id"):
13
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
14
+ if not ignore_status:
15
+ print(name, "no ignore status")
16
+ with zero.GatheredParameters([param]):
17
+ param = param.data.detach().cpu().clone()
18
+ else:
19
+ param = param.detach().cpu().clone()
20
+ return param
21
+
22
+
23
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
24
+ to_return = {
25
+ k: t
26
+ for k, t in named_params
27
+ if any(key_match in k for key_match in keys_to_match)
28
+ }
29
+ to_return = {
30
+ k: maybe_zero_3(v, ignore_status=True, name=k).cpu()
31
+ for k, v in to_return.items()
32
+ }
33
+ return to_return
34
+
35
+
36
+ class LLaVATrainer(Trainer):
37
+ def _save_checkpoint(self, model, trial, metrics=None):
38
+ if getattr(self.args, "tune_mm_mlp_adapter", False):
39
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
40
+
41
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
42
+
43
+ run_dir = self._get_output_dir(trial=trial)
44
+ output_dir = os.path.join(run_dir, checkpoint_folder)
45
+
46
+ # Only save Adapter
47
+ keys_to_match = ["mm_projector"]
48
+ if getattr(self.args, "use_im_start_end", False):
49
+ keys_to_match.extend(["embed_tokens", "embed_in"])
50
+
51
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(
52
+ self.model.named_parameters(), keys_to_match
53
+ )
54
+
55
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
56
+ self.model.config.save_pretrained(output_dir)
57
+ torch.save(
58
+ weight_to_save, os.path.join(output_dir, f"mm_projector.bin")
59
+ )
60
+ else:
61
+ super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
62
+
63
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
64
+ if getattr(self.args, "tune_mm_mlp_adapter", False):
65
+ pass
66
+ else:
67
+ super(LLaVATrainer, self)._save(output_dir, state_dict)
lisa_on_cuda/llava/train/train.py ADDED
@@ -0,0 +1,1038 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import copy
18
+ import json
19
+ import logging
20
+ import os
21
+ import pathlib
22
+ from dataclasses import dataclass, field
23
+ from typing import Dict, List, Optional, Sequence
24
+
25
+ import torch
26
+ import transformers
27
+ from llava import conversation as conversation_lib
28
+ from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
29
+ DEFAULT_IMAGE_TOKEN, IGNORE_INDEX,
30
+ IMAGE_TOKEN_INDEX)
31
+ from llava.mm_utils import tokenizer_image_token
32
+ from llava.model import *
33
+ from llava.train.llava_trainer import LLaVATrainer
34
+ from PIL import Image
35
+ from torch.utils.data import Dataset
36
+
37
+ local_rank = None
38
+
39
+
40
+ def rank0_print(*args):
41
+ if local_rank == 0:
42
+ print(*args)
43
+
44
+
45
+ @dataclass
46
+ class ModelArguments:
47
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
48
+ version: Optional[str] = field(default="v0")
49
+ freeze_backbone: bool = field(default=False)
50
+ tune_mm_mlp_adapter: bool = field(default=False)
51
+ vision_tower: Optional[str] = field(default=None)
52
+ mm_vision_select_layer: Optional[int] = field(
53
+ default=-1
54
+ ) # default to the last layer
55
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
56
+ mm_use_im_start_end: bool = field(default=False)
57
+ mm_use_im_patch_token: bool = field(default=True)
58
+ mm_vision_select_feature: Optional[str] = field(default="patch")
59
+
60
+
61
+ @dataclass
62
+ class DataArguments:
63
+ data_path: str = field(
64
+ default=None, metadata={"help": "Path to the training data."}
65
+ )
66
+ lazy_preprocess: bool = False
67
+ is_multimodal: bool = False
68
+ image_folder: Optional[str] = field(default=None)
69
+ image_aspect_ratio: str = "square"
70
+ image_grid_pinpoints: Optional[str] = field(default=None)
71
+
72
+
73
+ @dataclass
74
+ class TrainingArguments(transformers.TrainingArguments):
75
+ cache_dir: Optional[str] = field(default=None)
76
+ optim: str = field(default="adamw_torch")
77
+ remove_unused_columns: bool = field(default=False)
78
+ freeze_mm_mlp_adapter: bool = field(default=False)
79
+ mpt_attn_impl: Optional[str] = field(default="triton")
80
+ model_max_length: int = field(
81
+ default=512,
82
+ metadata={
83
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
84
+ },
85
+ )
86
+ double_quant: bool = field(
87
+ default=True,
88
+ metadata={
89
+ "help": "Compress the quantization statistics through double quantization."
90
+ },
91
+ )
92
+ quant_type: str = field(
93
+ default="nf4",
94
+ metadata={
95
+ "help": "Quantization data type to use. Should be one of `fp4` or `nf4`."
96
+ },
97
+ )
98
+ bits: int = field(default=16, metadata={"help": "How many bits to use."})
99
+ lora_enable: bool = False
100
+ lora_r: int = 64
101
+ lora_alpha: int = 16
102
+ lora_dropout: float = 0.05
103
+ lora_weight_path: str = ""
104
+ lora_bias: str = "none"
105
+
106
+
107
+ def maybe_zero_3(param, ignore_status=False, name=None):
108
+ from deepspeed import zero
109
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
110
+
111
+ if hasattr(param, "ds_id"):
112
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
113
+ if not ignore_status:
114
+ logging.warning(
115
+ f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
116
+ )
117
+ with zero.GatheredParameters([param]):
118
+ param = param.data.detach().cpu().clone()
119
+ else:
120
+ param = param.detach().cpu().clone()
121
+ return param
122
+
123
+
124
+ # Borrowed from peft.utils.get_peft_model_state_dict
125
+ def get_peft_state_maybe_zero_3(named_params, bias):
126
+ if bias == "none":
127
+ to_return = {k: t for k, t in named_params if "lora_" in k}
128
+ elif bias == "all":
129
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
130
+ elif bias == "lora_only":
131
+ to_return = {}
132
+ maybe_lora_bias = {}
133
+ lora_bias_names = set()
134
+ for k, t in named_params:
135
+ if "lora_" in k:
136
+ to_return[k] = t
137
+ bias_name = k.split("lora_")[0] + "bias"
138
+ lora_bias_names.add(bias_name)
139
+ elif "bias" in k:
140
+ maybe_lora_bias[k] = t
141
+ for k, t in maybe_lora_bias.items():
142
+ if bias_name in lora_bias_names:
143
+ to_return[bias_name] = t
144
+ else:
145
+ raise NotImplementedError
146
+ to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
147
+ return to_return
148
+
149
+
150
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
151
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
152
+ if require_grad_only:
153
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
154
+ to_return = {
155
+ k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
156
+ }
157
+ return to_return
158
+
159
+
160
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
161
+ to_return = {
162
+ k: t
163
+ for k, t in named_params
164
+ if any(key_match in k for key_match in keys_to_match)
165
+ }
166
+ to_return = {
167
+ k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
168
+ }
169
+ return to_return
170
+
171
+
172
+ def find_all_linear_names(model):
173
+ cls = torch.nn.Linear
174
+ lora_module_names = set()
175
+ for name, module in model.named_modules():
176
+ if isinstance(module, cls):
177
+ names = name.split(".")
178
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
179
+
180
+ if "lm_head" in lora_module_names: # needed for 16-bit
181
+ lora_module_names.remove("lm_head")
182
+ return list(lora_module_names)
183
+
184
+
185
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
186
+ """Collects the state dict and dump to disk."""
187
+
188
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False):
189
+ # Only save Adapter
190
+ keys_to_match = ["mm_projector"]
191
+ if getattr(trainer.args, "use_im_start_end", False):
192
+ keys_to_match.extend(["embed_tokens", "embed_in"])
193
+
194
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(
195
+ trainer.model.named_parameters(), keys_to_match
196
+ )
197
+ trainer.model.config.save_pretrained(output_dir)
198
+
199
+ current_folder = output_dir.split("/")[-1]
200
+ parent_folder = os.path.dirname(output_dir)
201
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
202
+ if current_folder.startswith("checkpoint-"):
203
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
204
+ os.makedirs(mm_projector_folder, exist_ok=True)
205
+ torch.save(
206
+ weight_to_save,
207
+ os.path.join(mm_projector_folder, f"{current_folder}.bin"),
208
+ )
209
+ else:
210
+ torch.save(
211
+ weight_to_save, os.path.join(output_dir, f"mm_projector.bin")
212
+ )
213
+ return
214
+
215
+ if trainer.deepspeed:
216
+ torch.cuda.synchronize()
217
+ trainer.save_model(output_dir)
218
+ return
219
+
220
+ state_dict = trainer.model.state_dict()
221
+ if trainer.args.should_save:
222
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
223
+ del state_dict
224
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
225
+
226
+
227
+ def smart_tokenizer_and_embedding_resize(
228
+ special_tokens_dict: Dict,
229
+ tokenizer: transformers.PreTrainedTokenizer,
230
+ model: transformers.PreTrainedModel,
231
+ ):
232
+ """Resize tokenizer and embedding.
233
+
234
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
235
+ """
236
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
237
+ model.resize_token_embeddings(len(tokenizer))
238
+
239
+ if num_new_tokens > 0:
240
+ input_embeddings = model.get_input_embeddings().weight.data
241
+ output_embeddings = model.get_output_embeddings().weight.data
242
+
243
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
244
+ dim=0, keepdim=True
245
+ )
246
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
247
+ dim=0, keepdim=True
248
+ )
249
+
250
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
251
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
252
+
253
+
254
+ def _tokenize_fn(
255
+ strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
256
+ ) -> Dict:
257
+ """Tokenize a list of strings."""
258
+ tokenized_list = [
259
+ tokenizer(
260
+ text,
261
+ return_tensors="pt",
262
+ padding="longest",
263
+ max_length=tokenizer.model_max_length,
264
+ truncation=True,
265
+ )
266
+ for text in strings
267
+ ]
268
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
269
+ input_ids_lens = labels_lens = [
270
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
271
+ for tokenized in tokenized_list
272
+ ]
273
+ return dict(
274
+ input_ids=input_ids,
275
+ labels=labels,
276
+ input_ids_lens=input_ids_lens,
277
+ labels_lens=labels_lens,
278
+ )
279
+
280
+
281
+ def _mask_targets(target, tokenized_lens, speakers):
282
+ # cur_idx = 0
283
+ cur_idx = tokenized_lens[0]
284
+ tokenized_lens = tokenized_lens[1:]
285
+ target[:cur_idx] = IGNORE_INDEX
286
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
287
+ if speaker == "human":
288
+ target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
289
+ cur_idx += tokenized_len
290
+
291
+
292
+ def _add_speaker_and_signal(header, source, get_conversation=True):
293
+ """Add speaker and start/end signal on each round."""
294
+ BEGIN_SIGNAL = "### "
295
+ END_SIGNAL = "\n"
296
+ conversation = header
297
+ for sentence in source:
298
+ from_str = sentence["from"]
299
+ if from_str.lower() == "human":
300
+ from_str = conversation_lib.default_conversation.roles[0]
301
+ elif from_str.lower() == "gpt":
302
+ from_str = conversation_lib.default_conversation.roles[1]
303
+ else:
304
+ from_str = "unknown"
305
+ sentence["value"] = (
306
+ BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
307
+ )
308
+ if get_conversation:
309
+ conversation += sentence["value"]
310
+ conversation += BEGIN_SIGNAL
311
+ return conversation
312
+
313
+
314
+ def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
315
+ is_multimodal = data_args.is_multimodal
316
+ if not is_multimodal:
317
+ return sources
318
+
319
+ for source in sources:
320
+ for sentence in source:
321
+ if DEFAULT_IMAGE_TOKEN in sentence["value"]:
322
+ sentence["value"] = (
323
+ sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
324
+ )
325
+ sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
326
+ sentence["value"] = sentence["value"].strip()
327
+ if "mmtag" in conversation_lib.default_conversation.version:
328
+ sentence["value"] = sentence["value"].replace(
329
+ DEFAULT_IMAGE_TOKEN,
330
+ "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>",
331
+ )
332
+ replace_token = DEFAULT_IMAGE_TOKEN
333
+ if data_args.mm_use_im_start_end:
334
+ replace_token = (
335
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
336
+ )
337
+ sentence["value"] = sentence["value"].replace(
338
+ DEFAULT_IMAGE_TOKEN, replace_token
339
+ )
340
+
341
+ return sources
342
+
343
+
344
+ def preprocess_llama_2(
345
+ sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
346
+ ) -> Dict:
347
+ conv = conversation_lib.default_conversation.copy()
348
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
349
+
350
+ # Apply prompt templates
351
+ conversations = []
352
+ for i, source in enumerate(sources):
353
+ if roles[source[0]["from"]] != conv.roles[0]:
354
+ # Skip the first one if it is not from human
355
+ source = source[1:]
356
+
357
+ conv.messages = []
358
+ for j, sentence in enumerate(source):
359
+ role = roles[sentence["from"]]
360
+ assert role == conv.roles[j % 2], f"{i}"
361
+ conv.append_message(role, sentence["value"])
362
+ conversations.append(conv.get_prompt())
363
+
364
+ # Tokenize conversations
365
+
366
+ if has_image:
367
+ input_ids = torch.stack(
368
+ [
369
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
370
+ for prompt in conversations
371
+ ],
372
+ dim=0,
373
+ )
374
+ else:
375
+ input_ids = tokenizer(
376
+ conversations,
377
+ return_tensors="pt",
378
+ padding="longest",
379
+ max_length=tokenizer.model_max_length,
380
+ truncation=True,
381
+ ).input_ids
382
+
383
+ targets = input_ids.clone()
384
+
385
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
386
+
387
+ # Mask targets
388
+ sep = "[/INST] "
389
+ for conversation, target in zip(conversations, targets):
390
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
391
+
392
+ rounds = conversation.split(conv.sep2)
393
+ cur_len = 1
394
+ target[:cur_len] = IGNORE_INDEX
395
+ for i, rou in enumerate(rounds):
396
+ if rou == "":
397
+ break
398
+
399
+ parts = rou.split(sep)
400
+ if len(parts) != 2:
401
+ break
402
+ parts[0] += sep
403
+
404
+ if has_image:
405
+ round_len = len(tokenizer_image_token(rou, tokenizer))
406
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
407
+ else:
408
+ round_len = len(tokenizer(rou).input_ids)
409
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
410
+
411
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
412
+
413
+ cur_len += round_len
414
+ target[cur_len:] = IGNORE_INDEX
415
+
416
+ if cur_len < tokenizer.model_max_length:
417
+ if cur_len != total_len:
418
+ target[:] = IGNORE_INDEX
419
+ print(
420
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
421
+ f" (ignored)"
422
+ )
423
+
424
+ return dict(
425
+ input_ids=input_ids,
426
+ labels=targets,
427
+ )
428
+
429
+
430
+ def preprocess_v1(
431
+ sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
432
+ ) -> Dict:
433
+ conv = conversation_lib.default_conversation.copy()
434
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
435
+
436
+ # Apply prompt templates
437
+ conversations = []
438
+ for i, source in enumerate(sources):
439
+ if roles[source[0]["from"]] != conv.roles[0]:
440
+ # Skip the first one if it is not from human
441
+ source = source[1:]
442
+
443
+ conv.messages = []
444
+ for j, sentence in enumerate(source):
445
+ role = roles[sentence["from"]]
446
+ assert role == conv.roles[j % 2], f"{i}"
447
+ conv.append_message(role, sentence["value"])
448
+ conversations.append(conv.get_prompt())
449
+
450
+ # Tokenize conversations
451
+
452
+ if has_image:
453
+ input_ids = torch.stack(
454
+ [
455
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
456
+ for prompt in conversations
457
+ ],
458
+ dim=0,
459
+ )
460
+ else:
461
+ input_ids = tokenizer(
462
+ conversations,
463
+ return_tensors="pt",
464
+ padding="longest",
465
+ max_length=tokenizer.model_max_length,
466
+ truncation=True,
467
+ ).input_ids
468
+
469
+ targets = input_ids.clone()
470
+
471
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
472
+
473
+ # Mask targets
474
+ sep = conv.sep + conv.roles[1] + ": "
475
+ for conversation, target in zip(conversations, targets):
476
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
477
+
478
+ rounds = conversation.split(conv.sep2)
479
+ cur_len = 1
480
+ target[:cur_len] = IGNORE_INDEX
481
+ for i, rou in enumerate(rounds):
482
+ if rou == "":
483
+ break
484
+
485
+ parts = rou.split(sep)
486
+ if len(parts) != 2:
487
+ break
488
+ parts[0] += sep
489
+
490
+ if has_image:
491
+ round_len = len(tokenizer_image_token(rou, tokenizer))
492
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
493
+ else:
494
+ round_len = len(tokenizer(rou).input_ids)
495
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
496
+
497
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
498
+
499
+ cur_len += round_len
500
+ target[cur_len:] = IGNORE_INDEX
501
+
502
+ if cur_len < tokenizer.model_max_length:
503
+ if cur_len != total_len:
504
+ target[:] = IGNORE_INDEX
505
+ print(
506
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
507
+ f" (ignored)"
508
+ )
509
+
510
+ return dict(
511
+ input_ids=input_ids,
512
+ labels=targets,
513
+ )
514
+
515
+
516
+ def preprocess_mpt(
517
+ sources,
518
+ tokenizer: transformers.PreTrainedTokenizer,
519
+ ) -> Dict:
520
+ conv = conversation_lib.default_conversation.copy()
521
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
522
+
523
+ # Apply prompt templates
524
+ conversations = []
525
+ for i, source in enumerate(sources):
526
+ if roles[source[0]["from"]] != conv.roles[0]:
527
+ # Skip the first one if it is not from human
528
+ source = source[1:]
529
+
530
+ conv.messages = []
531
+ for j, sentence in enumerate(source):
532
+ role = roles[sentence["from"]]
533
+ assert role == conv.roles[j % 2], f"{i}"
534
+ conv.append_message(role, sentence["value"])
535
+ conversations.append(conv.get_prompt())
536
+
537
+ # Tokenize conversations
538
+ input_ids = torch.stack(
539
+ [
540
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
541
+ for prompt in conversations
542
+ ],
543
+ dim=0,
544
+ )
545
+ targets = input_ids.clone()
546
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
547
+
548
+ # Mask targets
549
+ sep = conv.sep + conv.roles[1]
550
+ for conversation, target in zip(conversations, targets):
551
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
552
+
553
+ rounds = conversation.split(conv.sep)
554
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
555
+ for conv_idx in range(3, len(rounds), 2):
556
+ re_rounds.append(
557
+ conv.sep.join(rounds[conv_idx : conv_idx + 2])
558
+ ) # user + gpt
559
+ cur_len = 0
560
+ target[:cur_len] = IGNORE_INDEX
561
+ for i, rou in enumerate(re_rounds):
562
+ if rou == "":
563
+ break
564
+
565
+ parts = rou.split(sep)
566
+ if len(parts) != 2:
567
+ break
568
+ parts[0] += sep
569
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + len(
570
+ tokenizer_image_token(conv.sep, tokenizer)
571
+ )
572
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
573
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
574
+
575
+ cur_len += round_len
576
+ target[cur_len:] = IGNORE_INDEX
577
+
578
+ if cur_len < tokenizer.model_max_length:
579
+ if cur_len != total_len:
580
+ target[:] = IGNORE_INDEX
581
+ print(
582
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
583
+ f" (ignored)"
584
+ )
585
+
586
+ return dict(
587
+ input_ids=input_ids,
588
+ labels=targets,
589
+ )
590
+
591
+
592
+ def preprocess_plain(
593
+ sources: Sequence[str],
594
+ tokenizer: transformers.PreTrainedTokenizer,
595
+ ) -> Dict:
596
+ # add end signal and concatenate together
597
+ conversations = []
598
+ for source in sources:
599
+ assert len(source) == 2
600
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
601
+ source[0]["value"] = DEFAULT_IMAGE_TOKEN
602
+ conversation = (
603
+ source[0]["value"]
604
+ + source[1]["value"]
605
+ + conversation_lib.default_conversation.sep
606
+ )
607
+ conversations.append(conversation)
608
+ # tokenize conversations
609
+ input_ids = [
610
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
611
+ for prompt in conversations
612
+ ]
613
+ targets = copy.deepcopy(input_ids)
614
+ for target, source in zip(targets, sources):
615
+ tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
616
+ target[:tokenized_len] = IGNORE_INDEX
617
+
618
+ return dict(input_ids=input_ids, labels=targets)
619
+
620
+
621
+ def preprocess(
622
+ sources: Sequence[str],
623
+ tokenizer: transformers.PreTrainedTokenizer,
624
+ has_image: bool = False,
625
+ ) -> Dict:
626
+ """
627
+ Given a list of sources, each is a conversation list. This transform:
628
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
629
+ 2. Concatenate conversations together;
630
+ 3. Tokenize the concatenated conversation;
631
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
632
+ """
633
+ if (
634
+ conversation_lib.default_conversation.sep_style
635
+ == conversation_lib.SeparatorStyle.PLAIN
636
+ ):
637
+ return preprocess_plain(sources, tokenizer)
638
+ if (
639
+ conversation_lib.default_conversation.sep_style
640
+ == conversation_lib.SeparatorStyle.LLAMA_2
641
+ ):
642
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
643
+ if conversation_lib.default_conversation.version.startswith("v1"):
644
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
645
+ if conversation_lib.default_conversation.version == "mpt":
646
+ return preprocess_mpt(sources, tokenizer)
647
+ # add end signal and concatenate together
648
+ conversations = []
649
+ for source in sources:
650
+ header = f"{conversation_lib.default_conversation.system}\n\n"
651
+ conversation = _add_speaker_and_signal(header, source)
652
+ conversations.append(conversation)
653
+
654
+ # tokenize conversations
655
+ def get_tokenize_len(prompts):
656
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
657
+
658
+ if has_image:
659
+ input_ids = [
660
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
661
+ for prompt in conversations
662
+ ]
663
+ else:
664
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
665
+ input_ids = conversations_tokenized["input_ids"]
666
+
667
+ targets = copy.deepcopy(input_ids)
668
+ for target, source in zip(targets, sources):
669
+ if has_image:
670
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
671
+ else:
672
+ tokenized_lens = _tokenize_fn(
673
+ [header] + [s["value"] for s in source], tokenizer
674
+ )["input_ids_lens"]
675
+ speakers = [sentence["from"] for sentence in source]
676
+ _mask_targets(target, tokenized_lens, speakers)
677
+
678
+ return dict(input_ids=input_ids, labels=targets)
679
+
680
+
681
+ class LazySupervisedDataset(Dataset):
682
+ """Dataset for supervised fine-tuning."""
683
+
684
+ def __init__(
685
+ self,
686
+ data_path: str,
687
+ tokenizer: transformers.PreTrainedTokenizer,
688
+ data_args: DataArguments,
689
+ ):
690
+ super(LazySupervisedDataset, self).__init__()
691
+ list_data_dict = json.load(open(data_path, "r"))
692
+
693
+ rank0_print("Formatting inputs...Skip in lazy mode")
694
+ self.tokenizer = tokenizer
695
+ self.list_data_dict = list_data_dict
696
+ self.data_args = data_args
697
+
698
+ def __len__(self):
699
+ return len(self.list_data_dict)
700
+
701
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
702
+ sources = self.list_data_dict[i]
703
+ if isinstance(i, int):
704
+ sources = [sources]
705
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
706
+ if "image" in sources[0]:
707
+ image_file = self.list_data_dict[i]["image"]
708
+ image_folder = self.data_args.image_folder
709
+ processor = self.data_args.image_processor
710
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
711
+ if self.data_args.image_aspect_ratio == "pad":
712
+
713
+ def expand2square(pil_img, background_color):
714
+ width, height = pil_img.size
715
+ if width == height:
716
+ return pil_img
717
+ elif width > height:
718
+ result = Image.new(
719
+ pil_img.mode, (width, width), background_color
720
+ )
721
+ result.paste(pil_img, (0, (width - height) // 2))
722
+ return result
723
+ else:
724
+ result = Image.new(
725
+ pil_img.mode, (height, height), background_color
726
+ )
727
+ result.paste(pil_img, ((height - width) // 2, 0))
728
+ return result
729
+
730
+ image = expand2square(
731
+ image, tuple(int(x * 255) for x in processor.image_mean)
732
+ )
733
+ image = processor.preprocess(image, return_tensors="pt")[
734
+ "pixel_values"
735
+ ][0]
736
+ else:
737
+ image = processor.preprocess(image, return_tensors="pt")[
738
+ "pixel_values"
739
+ ][0]
740
+ sources = preprocess_multimodal(
741
+ copy.deepcopy([e["conversations"] for e in sources]), self.data_args
742
+ )
743
+ else:
744
+ sources = copy.deepcopy([e["conversations"] for e in sources])
745
+ data_dict = preprocess(
746
+ sources, self.tokenizer, has_image=("image" in self.list_data_dict[i])
747
+ )
748
+ if isinstance(i, int):
749
+ data_dict = dict(
750
+ input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]
751
+ )
752
+
753
+ # image exist in the data
754
+ if "image" in self.list_data_dict[i]:
755
+ data_dict["image"] = image
756
+ elif self.data_args.is_multimodal:
757
+ # image does not exist in the data, but the model is multimodal
758
+ crop_size = self.data_args.image_processor.crop_size
759
+ data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
760
+ return data_dict
761
+
762
+
763
+ @dataclass
764
+ class DataCollatorForSupervisedDataset(object):
765
+ """Collate examples for supervised fine-tuning."""
766
+
767
+ tokenizer: transformers.PreTrainedTokenizer
768
+
769
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
770
+ input_ids, labels = tuple(
771
+ [instance[key] for instance in instances] for key in ("input_ids", "labels")
772
+ )
773
+ input_ids = torch.nn.utils.rnn.pad_sequence(
774
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
775
+ )
776
+ labels = torch.nn.utils.rnn.pad_sequence(
777
+ labels, batch_first=True, padding_value=IGNORE_INDEX
778
+ )
779
+ input_ids = input_ids[:, : self.tokenizer.model_max_length]
780
+ labels = labels[:, : self.tokenizer.model_max_length]
781
+ batch = dict(
782
+ input_ids=input_ids,
783
+ labels=labels,
784
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
785
+ )
786
+
787
+ if "image" in instances[0]:
788
+ images = [instance["image"] for instance in instances]
789
+ if all(x is not None and x.shape == images[0].shape for x in images):
790
+ batch["images"] = torch.stack(images)
791
+ else:
792
+ batch["images"] = images
793
+
794
+ return batch
795
+
796
+
797
+ def make_supervised_data_module(
798
+ tokenizer: transformers.PreTrainedTokenizer, data_args
799
+ ) -> Dict:
800
+ """Make dataset and collator for supervised fine-tuning."""
801
+ train_dataset = LazySupervisedDataset(
802
+ tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args
803
+ )
804
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
805
+ return dict(
806
+ train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
807
+ )
808
+
809
+
810
+ def train():
811
+ global local_rank
812
+
813
+ parser = transformers.HfArgumentParser(
814
+ (ModelArguments, DataArguments, TrainingArguments)
815
+ )
816
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
817
+ local_rank = training_args.local_rank
818
+ compute_dtype = (
819
+ torch.float16
820
+ if training_args.fp16
821
+ else (torch.bfloat16 if training_args.bf16 else torch.float32)
822
+ )
823
+
824
+ bnb_model_from_pretrained_args = {}
825
+ if training_args.bits in [4, 8]:
826
+ from transformers import BitsAndBytesConfig
827
+
828
+ bnb_model_from_pretrained_args.update(
829
+ dict(
830
+ device_map={"": training_args.device},
831
+ load_in_4bit=training_args.bits == 4,
832
+ load_in_8bit=training_args.bits == 8,
833
+ quantization_config=BitsAndBytesConfig(
834
+ load_in_4bit=training_args.bits == 4,
835
+ load_in_8bit=training_args.bits == 8,
836
+ llm_int8_threshold=6.0,
837
+ llm_int8_has_fp16_weight=False,
838
+ bnb_4bit_compute_dtype=compute_dtype,
839
+ bnb_4bit_use_double_quant=training_args.double_quant,
840
+ bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
841
+ ),
842
+ )
843
+ )
844
+
845
+ if model_args.vision_tower is not None:
846
+ if "mpt" in model_args.model_name_or_path:
847
+ config = transformers.AutoConfig.from_pretrained(
848
+ model_args.model_name_or_path, trust_remote_code=True
849
+ )
850
+ config.attn_config["attn_impl"] = training_args.mpt_attn_impl
851
+ model = LlavaMPTForCausalLM.from_pretrained(
852
+ model_args.model_name_or_path,
853
+ config=config,
854
+ cache_dir=training_args.cache_dir,
855
+ **bnb_model_from_pretrained_args,
856
+ )
857
+ else:
858
+ model = LlavaLlamaForCausalLM.from_pretrained(
859
+ model_args.model_name_or_path,
860
+ cache_dir=training_args.cache_dir,
861
+ **bnb_model_from_pretrained_args,
862
+ )
863
+ else:
864
+ model = transformers.LlamaForCausalLM.from_pretrained(
865
+ model_args.model_name_or_path,
866
+ cache_dir=training_args.cache_dir,
867
+ **bnb_model_from_pretrained_args,
868
+ )
869
+ model.config.use_cache = False
870
+
871
+ if model_args.freeze_backbone:
872
+ model.model.requires_grad_(False)
873
+
874
+ if training_args.bits in [4, 8]:
875
+ from peft import prepare_model_for_kbit_training
876
+
877
+ model.config.torch_dtype = (
878
+ torch.float32
879
+ if training_args.fp16
880
+ else (torch.bfloat16 if training_args.bf16 else torch.float32)
881
+ )
882
+ model = prepare_model_for_kbit_training(
883
+ model, use_gradient_checkpointing=training_args.gradient_checkpointing
884
+ )
885
+
886
+ if training_args.gradient_checkpointing:
887
+ if hasattr(model, "enable_input_require_grads"):
888
+ model.enable_input_require_grads()
889
+ else:
890
+
891
+ def make_inputs_require_grad(module, input, output):
892
+ output.requires_grad_(True)
893
+
894
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
895
+
896
+ if training_args.lora_enable:
897
+ from peft import LoraConfig, get_peft_model
898
+
899
+ lora_config = LoraConfig(
900
+ r=training_args.lora_r,
901
+ lora_alpha=training_args.lora_alpha,
902
+ target_modules=find_all_linear_names(model),
903
+ lora_dropout=training_args.lora_dropout,
904
+ bias=training_args.lora_bias,
905
+ task_type="CAUSAL_LM",
906
+ )
907
+ if training_args.bits == 16:
908
+ if training_args.bf16:
909
+ model.to(torch.bfloat16)
910
+ if training_args.fp16:
911
+ model.to(torch.float16)
912
+ rank0_print("Adding LoRA adapters...")
913
+ model = get_peft_model(model, lora_config)
914
+
915
+ if "mpt" in model_args.model_name_or_path:
916
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
917
+ model_args.model_name_or_path,
918
+ cache_dir=training_args.cache_dir,
919
+ model_max_length=training_args.model_max_length,
920
+ padding_side="right",
921
+ )
922
+ else:
923
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
924
+ model_args.model_name_or_path,
925
+ cache_dir=training_args.cache_dir,
926
+ model_max_length=training_args.model_max_length,
927
+ padding_side="right",
928
+ use_fast=False,
929
+ )
930
+
931
+ if model_args.version == "v0":
932
+ if tokenizer.pad_token is None:
933
+ smart_tokenizer_and_embedding_resize(
934
+ special_tokens_dict=dict(pad_token="[PAD]"),
935
+ tokenizer=tokenizer,
936
+ model=model,
937
+ )
938
+ elif model_args.version == "v0.5":
939
+ tokenizer.pad_token = tokenizer.unk_token
940
+ else:
941
+ tokenizer.pad_token = tokenizer.unk_token
942
+ if model_args.version in conversation_lib.conv_templates:
943
+ conversation_lib.default_conversation = conversation_lib.conv_templates[
944
+ model_args.version
945
+ ]
946
+ else:
947
+ conversation_lib.default_conversation = conversation_lib.conv_templates[
948
+ "vicuna_v1"
949
+ ]
950
+
951
+ if model_args.vision_tower is not None:
952
+ model.get_model().initialize_vision_modules(
953
+ model_args=model_args, fsdp=training_args.fsdp
954
+ )
955
+
956
+ vision_tower = model.get_vision_tower()
957
+ vision_tower.to(dtype=torch.float16, device=training_args.device)
958
+
959
+ data_args.image_processor = vision_tower.image_processor
960
+ data_args.is_multimodal = True
961
+
962
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
963
+ model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
964
+
965
+ model.config.tune_mm_mlp_adapter = (
966
+ training_args.tune_mm_mlp_adapter
967
+ ) = model_args.tune_mm_mlp_adapter
968
+ if model_args.tune_mm_mlp_adapter:
969
+ model.requires_grad_(False)
970
+ for p in model.get_model().mm_projector.parameters():
971
+ p.requires_grad = True
972
+
973
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
974
+ if training_args.freeze_mm_mlp_adapter:
975
+ for p in model.get_model().mm_projector.parameters():
976
+ p.requires_grad = False
977
+
978
+ if training_args.bits in [4, 8]:
979
+ model.get_model().mm_projector.to(
980
+ dtype=compute_dtype, device=training_args.device
981
+ )
982
+
983
+ model.config.mm_use_im_start_end = (
984
+ data_args.mm_use_im_start_end
985
+ ) = model_args.mm_use_im_start_end
986
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
987
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
988
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
989
+
990
+ if training_args.bits in [4, 8]:
991
+ from peft.tuners.lora import LoraLayer
992
+
993
+ for name, module in model.named_modules():
994
+ if isinstance(module, LoraLayer):
995
+ if training_args.bf16:
996
+ module = module.to(torch.bfloat16)
997
+ if "norm" in name:
998
+ module = module.to(torch.float32)
999
+ if "lm_head" in name or "embed_tokens" in name:
1000
+ if hasattr(module, "weight"):
1001
+ if training_args.bf16 and module.weight.dtype == torch.float32:
1002
+ module = module.to(torch.bfloat16)
1003
+
1004
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
1005
+ trainer = LLaVATrainer(
1006
+ model=model, tokenizer=tokenizer, args=training_args, **data_module
1007
+ )
1008
+
1009
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
1010
+ trainer.train(resume_from_checkpoint=True)
1011
+ else:
1012
+ trainer.train()
1013
+ trainer.save_state()
1014
+
1015
+ model.config.use_cache = True
1016
+
1017
+ if training_args.lora_enable:
1018
+ state_dict = get_peft_state_maybe_zero_3(
1019
+ model.named_parameters(), training_args.lora_bias
1020
+ )
1021
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
1022
+ model.named_parameters()
1023
+ )
1024
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
1025
+ model.config.save_pretrained(training_args.output_dir)
1026
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
1027
+ torch.save(
1028
+ non_lora_state_dict,
1029
+ os.path.join(training_args.output_dir, "non_lora_trainables.bin"),
1030
+ )
1031
+ else:
1032
+ safe_save_model_for_hf_trainer(
1033
+ trainer=trainer, output_dir=training_args.output_dir
1034
+ )
1035
+
1036
+
1037
+ if __name__ == "__main__":
1038
+ train()
lisa_on_cuda/llava/train/train_mem.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3
+ # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4
+
5
+ # Need to call this before importing transformers.
6
+ from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7
+
8
+ replace_llama_attn_with_flash_attn()
9
+
10
+ from .train import train
11
+
12
+
13
+ if __name__ == "__main__":
14
+ train()
lisa_on_cuda/llava/utils.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+ from .constants import LOGDIR
9
+
10
+ server_error_msg = (
11
+ "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
+ )
13
+ moderation_msg = (
14
+ "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
15
+ )
16
+
17
+ handler = None
18
+
19
+
20
+ def build_logger(logger_name, logger_filename):
21
+ global handler
22
+
23
+ formatter = logging.Formatter(
24
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
25
+ datefmt="%Y-%m-%d %H:%M:%S",
26
+ )
27
+
28
+ # Set the format of root handlers
29
+ if not logging.getLogger().handlers:
30
+ logging.basicConfig(level=logging.INFO)
31
+ logging.getLogger().handlers[0].setFormatter(formatter)
32
+
33
+ # Redirect stdout and stderr to loggers
34
+ stdout_logger = logging.getLogger("stdout")
35
+ stdout_logger.setLevel(logging.INFO)
36
+ sl = StreamToLogger(stdout_logger, logging.INFO)
37
+ sys.stdout = sl
38
+
39
+ stderr_logger = logging.getLogger("stderr")
40
+ stderr_logger.setLevel(logging.ERROR)
41
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
42
+ sys.stderr = sl
43
+
44
+ # Get logger
45
+ logger = logging.getLogger(logger_name)
46
+ logger.setLevel(logging.INFO)
47
+
48
+ # Add a file handler for all loggers
49
+ if handler is None:
50
+ os.makedirs(LOGDIR, exist_ok=True)
51
+ filename = os.path.join(LOGDIR, logger_filename)
52
+ handler = logging.handlers.TimedRotatingFileHandler(
53
+ filename, when="D", utc=True
54
+ )
55
+ handler.setFormatter(formatter)
56
+
57
+ for name, item in logging.root.manager.loggerDict.items():
58
+ if isinstance(item, logging.Logger):
59
+ item.addHandler(handler)
60
+
61
+ return logger
62
+
63
+
64
+ class StreamToLogger(object):
65
+ """
66
+ Fake file-like stream object that redirects writes to a logger instance.
67
+ """
68
+
69
+ def __init__(self, logger, log_level=logging.INFO):
70
+ self.terminal = sys.stdout
71
+ self.logger = logger
72
+ self.log_level = log_level
73
+ self.linebuf = ""
74
+
75
+ def __getattr__(self, attr):
76
+ return getattr(self.terminal, attr)
77
+
78
+ def write(self, buf):
79
+ temp_linebuf = self.linebuf + buf
80
+ self.linebuf = ""
81
+ for line in temp_linebuf.splitlines(True):
82
+ # From the io.TextIOWrapper docs:
83
+ # On output, if newline is None, any '\n' characters written
84
+ # are translated to the system default line separator.
85
+ # By default sys.stdout.write() expects '\n' newlines and then
86
+ # translates them so this is still cross platform.
87
+ if line[-1] == "\n":
88
+ self.logger.log(self.log_level, line.rstrip())
89
+ else:
90
+ self.linebuf += line
91
+
92
+ def flush(self):
93
+ if self.linebuf != "":
94
+ self.logger.log(self.log_level, self.linebuf.rstrip())
95
+ self.linebuf = ""
96
+
97
+
98
+ def disable_torch_init():
99
+ """
100
+ Disable the redundant torch default initialization to accelerate model creation.
101
+ """
102
+ import torch
103
+
104
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
105
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
106
+
107
+
108
+ def violates_moderation(text):
109
+ """
110
+ Check whether the text violates OpenAI moderation API.
111
+ """
112
+ url = "https://api.openai.com/v1/moderations"
113
+ headers = {
114
+ "Content-Type": "application/json",
115
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
116
+ }
117
+ text = text.replace("\n", "")
118
+ data = "{" + '"input": ' + f'"{text}"' + "}"
119
+ data = data.encode("utf-8")
120
+ try:
121
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
122
+ flagged = ret.json()["results"][0]["flagged"]
123
+ except requests.exceptions.RequestException as e:
124
+ flagged = False
125
+ except KeyError as e:
126
+ flagged = False
127
+
128
+ return flagged
129
+
130
+
131
+ def pretty_print_semaphore(semaphore):
132
+ if semaphore is None:
133
+ return "None"
134
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
lisa_on_cuda/routes.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from fastapi import APIRouter
4
+
5
+ from lisa_on_cuda import app_logger
6
+
7
+
8
+ router = APIRouter()
9
+
10
+
11
+ @router.get("/health")
12
+ def health() -> str:
13
+ try:
14
+ from samgis_core.__version__ import __version__ as version_core
15
+ from gradio import __version__ as gradio_version
16
+
17
+ app_logger.info(f"still alive, gradio_version:{gradio_version}, version_core:{version_core}.")
18
+ return json.dumps({"msg": "lisa on cuda: still alive..."})
19
+ except Exception as e:
20
+ app_logger.error(f"exception:{e}.")
21
+ return json.dumps({"msg": "request failed"})
lisa_on_cuda/segment_anything/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .automatic_mask_generator import SamAutomaticMaskGenerator
8
+ from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h,
9
+ build_sam_vit_l, sam_model_registry)
10
+ from .predictor import SamPredictor
lisa_on_cuda/segment_anything/automatic_mask_generator.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
12
+
13
+ from .modeling import Sam
14
+ from .predictor import SamPredictor
15
+ from .utils.amg import (MaskData, area_from_rle, batch_iterator,
16
+ batched_mask_to_box, box_xyxy_to_xywh,
17
+ build_all_layer_point_grids, calculate_stability_score,
18
+ coco_encode_rle, generate_crop_boxes,
19
+ is_box_near_crop_edge, mask_to_rle_pytorch,
20
+ remove_small_regions, rle_to_mask, uncrop_boxes_xyxy,
21
+ uncrop_masks, uncrop_points)
22
+
23
+
24
+ class SamAutomaticMaskGenerator:
25
+ def __init__(
26
+ self,
27
+ model: Sam,
28
+ points_per_side: Optional[int] = 32,
29
+ points_per_batch: int = 64,
30
+ pred_iou_thresh: float = 0.88,
31
+ stability_score_thresh: float = 0.95,
32
+ stability_score_offset: float = 1.0,
33
+ box_nms_thresh: float = 0.7,
34
+ crop_n_layers: int = 0,
35
+ crop_nms_thresh: float = 0.7,
36
+ crop_overlap_ratio: float = 512 / 1500,
37
+ crop_n_points_downscale_factor: int = 1,
38
+ point_grids: Optional[List[np.ndarray]] = None,
39
+ min_mask_region_area: int = 0,
40
+ output_mode: str = "binary_mask",
41
+ ) -> None:
42
+ """
43
+ Using a SAM model, generates masks for the entire image.
44
+ Generates a grid of point prompts over the image, then filters
45
+ low quality and duplicate masks. The default settings are chosen
46
+ for SAM with a ViT-H backbone.
47
+
48
+ Arguments:
49
+ model (Sam): The SAM model to use for mask prediction.
50
+ points_per_side (int or None): The number of points to be sampled
51
+ along one side of the image. The total number of points is
52
+ points_per_side**2. If None, 'point_grids' must provide explicit
53
+ point sampling.
54
+ points_per_batch (int): Sets the number of points run simultaneously
55
+ by the model. Higher numbers may be faster but use more GPU memory.
56
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
57
+ model's predicted mask quality.
58
+ stability_score_thresh (float): A filtering threshold in [0,1], using
59
+ the stability of the mask under changes to the cutoff used to binarize
60
+ the model's mask predictions.
61
+ stability_score_offset (float): The amount to shift the cutoff when
62
+ calculated the stability score.
63
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
64
+ suppression to filter duplicate masks.
65
+ crop_n_layers (int): If >0, mask prediction will be run again on
66
+ crops of the image. Sets the number of layers to run, where each
67
+ layer has 2**i_layer number of image crops.
68
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
69
+ suppression to filter duplicate masks between different crops.
70
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
71
+ In the first crop layer, crops will overlap by this fraction of
72
+ the image length. Later layers with more crops scale down this overlap.
73
+ crop_n_points_downscale_factor (int): The number of points-per-side
74
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
75
+ point_grids (list(np.ndarray) or None): A list over explicit grids
76
+ of points used for sampling, normalized to [0,1]. The nth grid in the
77
+ list is used in the nth crop layer. Exclusive with points_per_side.
78
+ min_mask_region_area (int): If >0, postprocessing will be applied
79
+ to remove disconnected regions and holes in masks with area smaller
80
+ than min_mask_region_area. Requires opencv.
81
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
82
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
83
+ For large resolutions, 'binary_mask' may consume large amounts of
84
+ memory.
85
+ """
86
+
87
+ assert (points_per_side is None) != (
88
+ point_grids is None
89
+ ), "Exactly one of points_per_side or point_grid must be provided."
90
+ if points_per_side is not None:
91
+ self.point_grids = build_all_layer_point_grids(
92
+ points_per_side,
93
+ crop_n_layers,
94
+ crop_n_points_downscale_factor,
95
+ )
96
+ elif point_grids is not None:
97
+ self.point_grids = point_grids
98
+ else:
99
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
100
+
101
+ assert output_mode in [
102
+ "binary_mask",
103
+ "uncompressed_rle",
104
+ "coco_rle",
105
+ ], f"Unknown output_mode {output_mode}."
106
+ if output_mode == "coco_rle":
107
+ from pycocotools import \
108
+ mask as mask_utils # type: ignore # noqa: F401
109
+
110
+ if min_mask_region_area > 0:
111
+ import cv2 # type: ignore # noqa: F401
112
+
113
+ self.predictor = SamPredictor(model)
114
+ self.points_per_batch = points_per_batch
115
+ self.pred_iou_thresh = pred_iou_thresh
116
+ self.stability_score_thresh = stability_score_thresh
117
+ self.stability_score_offset = stability_score_offset
118
+ self.box_nms_thresh = box_nms_thresh
119
+ self.crop_n_layers = crop_n_layers
120
+ self.crop_nms_thresh = crop_nms_thresh
121
+ self.crop_overlap_ratio = crop_overlap_ratio
122
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
123
+ self.min_mask_region_area = min_mask_region_area
124
+ self.output_mode = output_mode
125
+
126
+ @torch.no_grad()
127
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
128
+ """
129
+ Generates masks for the given image.
130
+
131
+ Arguments:
132
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
133
+
134
+ Returns:
135
+ list(dict(str, any)): A list over records for masks. Each record is
136
+ a dict containing the following keys:
137
+ segmentation (dict(str, any) or np.ndarray): The mask. If
138
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
139
+ is a dictionary containing the RLE.
140
+ bbox (list(float)): The box around the mask, in XYWH format.
141
+ area (int): The area in pixels of the mask.
142
+ predicted_iou (float): The model's own prediction of the mask's
143
+ quality. This is filtered by the pred_iou_thresh parameter.
144
+ point_coords (list(list(float))): The point coordinates input
145
+ to the model to generate this mask.
146
+ stability_score (float): A measure of the mask's quality. This
147
+ is filtered on using the stability_score_thresh parameter.
148
+ crop_box (list(float)): The crop of the image used to generate
149
+ the mask, given in XYWH format.
150
+ """
151
+
152
+ # Generate masks
153
+ mask_data = self._generate_masks(image)
154
+
155
+ # Filter small disconnected regions and holes in masks
156
+ if self.min_mask_region_area > 0:
157
+ mask_data = self.postprocess_small_regions(
158
+ mask_data,
159
+ self.min_mask_region_area,
160
+ max(self.box_nms_thresh, self.crop_nms_thresh),
161
+ )
162
+
163
+ # Encode masks
164
+ if self.output_mode == "coco_rle":
165
+ mask_data["segmentations"] = [
166
+ coco_encode_rle(rle) for rle in mask_data["rles"]
167
+ ]
168
+ elif self.output_mode == "binary_mask":
169
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
170
+ else:
171
+ mask_data["segmentations"] = mask_data["rles"]
172
+
173
+ # Write mask records
174
+ curr_anns = []
175
+ for idx in range(len(mask_data["segmentations"])):
176
+ ann = {
177
+ "segmentation": mask_data["segmentations"][idx],
178
+ "area": area_from_rle(mask_data["rles"][idx]),
179
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
180
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
181
+ "point_coords": [mask_data["points"][idx].tolist()],
182
+ "stability_score": mask_data["stability_score"][idx].item(),
183
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
184
+ }
185
+ curr_anns.append(ann)
186
+
187
+ return curr_anns
188
+
189
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
190
+ orig_size = image.shape[:2]
191
+ crop_boxes, layer_idxs = generate_crop_boxes(
192
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
193
+ )
194
+
195
+ # Iterate over image crops
196
+ data = MaskData()
197
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
198
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
199
+ data.cat(crop_data)
200
+
201
+ # Remove duplicate masks between crops
202
+ if len(crop_boxes) > 1:
203
+ # Prefer masks from smaller crops
204
+ scores = 1 / box_area(data["crop_boxes"])
205
+ scores = scores.to(data["boxes"].device)
206
+ keep_by_nms = batched_nms(
207
+ data["boxes"].float(),
208
+ scores,
209
+ torch.zeros_like(data["boxes"][:, 0]), # categories
210
+ iou_threshold=self.crop_nms_thresh,
211
+ )
212
+ data.filter(keep_by_nms)
213
+
214
+ data.to_numpy()
215
+ return data
216
+
217
+ def _process_crop(
218
+ self,
219
+ image: np.ndarray,
220
+ crop_box: List[int],
221
+ crop_layer_idx: int,
222
+ orig_size: Tuple[int, ...],
223
+ ) -> MaskData:
224
+ # Crop the image and calculate embeddings
225
+ x0, y0, x1, y1 = crop_box
226
+ cropped_im = image[y0:y1, x0:x1, :]
227
+ cropped_im_size = cropped_im.shape[:2]
228
+ self.predictor.set_image(cropped_im)
229
+
230
+ # Get points for this crop
231
+ points_scale = np.array(cropped_im_size)[None, ::-1]
232
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
233
+
234
+ # Generate masks for this crop in batches
235
+ data = MaskData()
236
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
237
+ batch_data = self._process_batch(
238
+ points, cropped_im_size, crop_box, orig_size
239
+ )
240
+ data.cat(batch_data)
241
+ del batch_data
242
+ self.predictor.reset_image()
243
+
244
+ # Remove duplicates within this crop.
245
+ keep_by_nms = batched_nms(
246
+ data["boxes"].float(),
247
+ data["iou_preds"],
248
+ torch.zeros_like(data["boxes"][:, 0]), # categories
249
+ iou_threshold=self.box_nms_thresh,
250
+ )
251
+ data.filter(keep_by_nms)
252
+
253
+ # Return to the original image frame
254
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
255
+ data["points"] = uncrop_points(data["points"], crop_box)
256
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
257
+
258
+ return data
259
+
260
+ def _process_batch(
261
+ self,
262
+ points: np.ndarray,
263
+ im_size: Tuple[int, ...],
264
+ crop_box: List[int],
265
+ orig_size: Tuple[int, ...],
266
+ ) -> MaskData:
267
+ orig_h, orig_w = orig_size
268
+
269
+ # Run model on this batch
270
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
271
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
272
+ in_labels = torch.ones(
273
+ in_points.shape[0], dtype=torch.int, device=in_points.device
274
+ )
275
+ masks, iou_preds, _ = self.predictor.predict_torch(
276
+ in_points[:, None, :],
277
+ in_labels[:, None],
278
+ multimask_output=True,
279
+ return_logits=True,
280
+ )
281
+
282
+ # Serialize predictions and store in MaskData
283
+ data = MaskData(
284
+ masks=masks.flatten(0, 1),
285
+ iou_preds=iou_preds.flatten(0, 1),
286
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
287
+ )
288
+ del masks
289
+
290
+ # Filter by predicted IoU
291
+ if self.pred_iou_thresh > 0.0:
292
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
293
+ data.filter(keep_mask)
294
+
295
+ # Calculate stability score
296
+ data["stability_score"] = calculate_stability_score(
297
+ data["masks"],
298
+ self.predictor.model.mask_threshold,
299
+ self.stability_score_offset,
300
+ )
301
+ if self.stability_score_thresh > 0.0:
302
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
303
+ data.filter(keep_mask)
304
+
305
+ # Threshold masks and calculate boxes
306
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
307
+ data["boxes"] = batched_mask_to_box(data["masks"])
308
+
309
+ # Filter boxes that touch crop boundaries
310
+ keep_mask = ~is_box_near_crop_edge(
311
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
312
+ )
313
+ if not torch.all(keep_mask):
314
+ data.filter(keep_mask)
315
+
316
+ # Compress to RLE
317
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
318
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
319
+ del data["masks"]
320
+
321
+ return data
322
+
323
+ @staticmethod
324
+ def postprocess_small_regions(
325
+ mask_data: MaskData, min_area: int, nms_thresh: float
326
+ ) -> MaskData:
327
+ """
328
+ Removes small disconnected regions and holes in masks, then reruns
329
+ box NMS to remove any new duplicates.
330
+
331
+ Edits mask_data in place.
332
+
333
+ Requires open-cv as a dependency.
334
+ """
335
+ if len(mask_data["rles"]) == 0:
336
+ return mask_data
337
+
338
+ # Filter small disconnected regions and holes
339
+ new_masks = []
340
+ scores = []
341
+ for rle in mask_data["rles"]:
342
+ mask = rle_to_mask(rle)
343
+
344
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
345
+ unchanged = not changed
346
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
347
+ unchanged = unchanged and not changed
348
+
349
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
350
+ # Give score=0 to changed masks and score=1 to unchanged masks
351
+ # so NMS will prefer ones that didn't need postprocessing
352
+ scores.append(float(unchanged))
353
+
354
+ # Recalculate boxes and remove any new duplicates
355
+ masks = torch.cat(new_masks, dim=0)
356
+ boxes = batched_mask_to_box(masks)
357
+ keep_by_nms = batched_nms(
358
+ boxes.float(),
359
+ torch.as_tensor(scores),
360
+ torch.zeros_like(boxes[:, 0]), # categories
361
+ iou_threshold=nms_thresh,
362
+ )
363
+
364
+ # Only recalculate RLEs for masks that have changed
365
+ for i_mask in keep_by_nms:
366
+ if scores[i_mask] == 0.0:
367
+ mask_torch = masks[i_mask].unsqueeze(0)
368
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
369
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
370
+ mask_data.filter(keep_by_nms)
371
+
372
+ return mask_data
lisa_on_cuda/segment_anything/build_sam.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from functools import partial
8
+
9
+ import torch
10
+
11
+ from .modeling import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam,
12
+ TwoWayTransformer)
13
+
14
+
15
+ def build_sam_vit_h(checkpoint=None):
16
+ return _build_sam(
17
+ encoder_embed_dim=1280,
18
+ encoder_depth=32,
19
+ encoder_num_heads=16,
20
+ encoder_global_attn_indexes=[7, 15, 23, 31],
21
+ checkpoint=checkpoint,
22
+ )
23
+
24
+
25
+ build_sam = build_sam_vit_h
26
+
27
+
28
+ def build_sam_vit_l(checkpoint=None):
29
+ return _build_sam(
30
+ encoder_embed_dim=1024,
31
+ encoder_depth=24,
32
+ encoder_num_heads=16,
33
+ encoder_global_attn_indexes=[5, 11, 17, 23],
34
+ checkpoint=checkpoint,
35
+ )
36
+
37
+
38
+ def build_sam_vit_b(checkpoint=None):
39
+ return _build_sam(
40
+ encoder_embed_dim=768,
41
+ encoder_depth=12,
42
+ encoder_num_heads=12,
43
+ encoder_global_attn_indexes=[2, 5, 8, 11],
44
+ checkpoint=checkpoint,
45
+ )
46
+
47
+
48
+ sam_model_registry = {
49
+ "default": build_sam_vit_h,
50
+ "vit_h": build_sam_vit_h,
51
+ "vit_l": build_sam_vit_l,
52
+ "vit_b": build_sam_vit_b,
53
+ }
54
+
55
+
56
+ def _build_sam(
57
+ encoder_embed_dim,
58
+ encoder_depth,
59
+ encoder_num_heads,
60
+ encoder_global_attn_indexes,
61
+ checkpoint=None,
62
+ ):
63
+ prompt_embed_dim = 256
64
+ image_size = 1024
65
+ vit_patch_size = 16
66
+ image_embedding_size = image_size // vit_patch_size
67
+ sam = Sam(
68
+ image_encoder=ImageEncoderViT(
69
+ depth=encoder_depth,
70
+ embed_dim=encoder_embed_dim,
71
+ img_size=image_size,
72
+ mlp_ratio=4,
73
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
74
+ num_heads=encoder_num_heads,
75
+ patch_size=vit_patch_size,
76
+ qkv_bias=True,
77
+ use_rel_pos=True,
78
+ global_attn_indexes=encoder_global_attn_indexes,
79
+ window_size=14,
80
+ out_chans=prompt_embed_dim,
81
+ ),
82
+ prompt_encoder=PromptEncoder(
83
+ embed_dim=prompt_embed_dim,
84
+ image_embedding_size=(image_embedding_size, image_embedding_size),
85
+ input_image_size=(image_size, image_size),
86
+ mask_in_chans=16,
87
+ ),
88
+ mask_decoder=MaskDecoder(
89
+ num_multimask_outputs=3,
90
+ transformer=TwoWayTransformer(
91
+ depth=2,
92
+ embedding_dim=prompt_embed_dim,
93
+ mlp_dim=2048,
94
+ num_heads=8,
95
+ ),
96
+ transformer_dim=prompt_embed_dim,
97
+ iou_head_depth=3,
98
+ iou_head_hidden_dim=256,
99
+ ),
100
+ pixel_mean=[123.675, 116.28, 103.53],
101
+ pixel_std=[58.395, 57.12, 57.375],
102
+ )
103
+ sam.eval()
104
+ if checkpoint is not None:
105
+ with open(checkpoint, "rb") as f:
106
+ state_dict = torch.load(f)
107
+ sam.load_state_dict(state_dict, strict=False)
108
+ return sam
lisa_on_cuda/segment_anything/modeling/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .image_encoder import ImageEncoderViT
8
+ from .mask_decoder import MaskDecoder
9
+ from .prompt_encoder import PromptEncoder
10
+ from .sam import Sam
11
+ from .transformer import TwoWayTransformer
lisa_on_cuda/segment_anything/modeling/common.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Type
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+ class MLPBlock(nn.Module):
14
+ def __init__(
15
+ self,
16
+ embedding_dim: int,
17
+ mlp_dim: int,
18
+ act: Type[nn.Module] = nn.GELU,
19
+ ) -> None:
20
+ super().__init__()
21
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23
+ self.act = act()
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ return self.lin2(self.act(self.lin1(x)))
27
+
28
+
29
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31
+ class LayerNorm2d(nn.Module):
32
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33
+ super().__init__()
34
+ self.weight = nn.Parameter(torch.ones(num_channels))
35
+ self.bias = nn.Parameter(torch.zeros(num_channels))
36
+ self.eps = eps
37
+
38
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
39
+ u = x.mean(1, keepdim=True)
40
+ s = (x - u).pow(2).mean(1, keepdim=True)
41
+ x = (x - u) / torch.sqrt(s + self.eps)
42
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
43
+ return x
lisa_on_cuda/segment_anything/modeling/image_encoder.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, Type
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .common import LayerNorm2d, MLPBlock
14
+
15
+
16
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17
+ class ImageEncoderViT(nn.Module):
18
+ def __init__(
19
+ self,
20
+ img_size: int = 1024,
21
+ patch_size: int = 16,
22
+ in_chans: int = 3,
23
+ embed_dim: int = 768,
24
+ depth: int = 12,
25
+ num_heads: int = 12,
26
+ mlp_ratio: float = 4.0,
27
+ out_chans: int = 256,
28
+ qkv_bias: bool = True,
29
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
30
+ act_layer: Type[nn.Module] = nn.GELU,
31
+ use_abs_pos: bool = True,
32
+ use_rel_pos: bool = False,
33
+ rel_pos_zero_init: bool = True,
34
+ window_size: int = 0,
35
+ global_attn_indexes: Tuple[int, ...] = (),
36
+ ) -> None:
37
+ """
38
+ Args:
39
+ img_size (int): Input image size.
40
+ patch_size (int): Patch size.
41
+ in_chans (int): Number of input image channels.
42
+ embed_dim (int): Patch embedding dimension.
43
+ depth (int): Depth of ViT.
44
+ num_heads (int): Number of attention heads in each ViT block.
45
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
46
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
47
+ norm_layer (nn.Module): Normalization layer.
48
+ act_layer (nn.Module): Activation layer.
49
+ use_abs_pos (bool): If True, use absolute positional embeddings.
50
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
51
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
52
+ window_size (int): Window size for window attention blocks.
53
+ global_attn_indexes (list): Indexes for blocks using global attention.
54
+ """
55
+ super().__init__()
56
+ self.img_size = img_size
57
+ self.embed_dim = embed_dim
58
+ self.out_chans = out_chans
59
+
60
+ self.patch_embed = PatchEmbed(
61
+ kernel_size=(patch_size, patch_size),
62
+ stride=(patch_size, patch_size),
63
+ in_chans=in_chans,
64
+ embed_dim=embed_dim,
65
+ )
66
+
67
+ self.pos_embed: Optional[nn.Parameter] = None
68
+ if use_abs_pos:
69
+ # Initialize absolute positional embedding with pretrain image size.
70
+ self.pos_embed = nn.Parameter(
71
+ torch.zeros(
72
+ 1, img_size // patch_size, img_size // patch_size, embed_dim
73
+ )
74
+ )
75
+
76
+ self.blocks = nn.ModuleList()
77
+ for i in range(depth):
78
+ block = Block(
79
+ dim=embed_dim,
80
+ num_heads=num_heads,
81
+ mlp_ratio=mlp_ratio,
82
+ qkv_bias=qkv_bias,
83
+ norm_layer=norm_layer,
84
+ act_layer=act_layer,
85
+ use_rel_pos=use_rel_pos,
86
+ rel_pos_zero_init=rel_pos_zero_init,
87
+ window_size=window_size if i not in global_attn_indexes else 0,
88
+ input_size=(img_size // patch_size, img_size // patch_size),
89
+ )
90
+ self.blocks.append(block)
91
+
92
+ self.neck = nn.Sequential(
93
+ nn.Conv2d(
94
+ embed_dim,
95
+ out_chans,
96
+ kernel_size=1,
97
+ bias=False,
98
+ ),
99
+ LayerNorm2d(out_chans),
100
+ nn.Conv2d(
101
+ out_chans,
102
+ out_chans,
103
+ kernel_size=3,
104
+ padding=1,
105
+ bias=False,
106
+ ),
107
+ LayerNorm2d(out_chans),
108
+ )
109
+
110
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
111
+ x = self.patch_embed(x)
112
+ if self.pos_embed is not None:
113
+ x = x + self.pos_embed
114
+
115
+ for blk in self.blocks:
116
+ x = blk(x)
117
+
118
+ dtype = x.dtype
119
+ if dtype == torch.float16: # prevent overflow
120
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
121
+ x = self.neck(x.permute(0, 3, 1, 2))
122
+ x = x.to(dtype)
123
+ else:
124
+ x = self.neck(x.permute(0, 3, 1, 2))
125
+ return x
126
+
127
+
128
+ class Block(nn.Module):
129
+ """Transformer blocks with support of window attention and residual propagation blocks"""
130
+
131
+ def __init__(
132
+ self,
133
+ dim: int,
134
+ num_heads: int,
135
+ mlp_ratio: float = 4.0,
136
+ qkv_bias: bool = True,
137
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
138
+ act_layer: Type[nn.Module] = nn.GELU,
139
+ use_rel_pos: bool = False,
140
+ rel_pos_zero_init: bool = True,
141
+ window_size: int = 0,
142
+ input_size: Optional[Tuple[int, int]] = None,
143
+ ) -> None:
144
+ """
145
+ Args:
146
+ dim (int): Number of input channels.
147
+ num_heads (int): Number of attention heads in each ViT block.
148
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
149
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
150
+ norm_layer (nn.Module): Normalization layer.
151
+ act_layer (nn.Module): Activation layer.
152
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
153
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
154
+ window_size (int): Window size for window attention blocks. If it equals 0, then
155
+ use global attention.
156
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
157
+ positional parameter size.
158
+ """
159
+ super().__init__()
160
+ self.norm1 = norm_layer(dim)
161
+ self.attn = Attention(
162
+ dim,
163
+ num_heads=num_heads,
164
+ qkv_bias=qkv_bias,
165
+ use_rel_pos=use_rel_pos,
166
+ rel_pos_zero_init=rel_pos_zero_init,
167
+ input_size=input_size if window_size == 0 else (window_size, window_size),
168
+ )
169
+
170
+ self.norm2 = norm_layer(dim)
171
+ self.mlp = MLPBlock(
172
+ embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
173
+ )
174
+
175
+ self.window_size = window_size
176
+
177
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
178
+ shortcut = x
179
+ x = self.norm1(x)
180
+ # Window partition
181
+ if self.window_size > 0:
182
+ H, W = x.shape[1], x.shape[2]
183
+ x, pad_hw = window_partition(x, self.window_size)
184
+
185
+ x = self.attn(x)
186
+ # Reverse window partition
187
+ if self.window_size > 0:
188
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
189
+
190
+ x = shortcut + x
191
+ x = x + self.mlp(self.norm2(x))
192
+
193
+ return x
194
+
195
+
196
+ class Attention(nn.Module):
197
+ """Multi-head Attention block with relative position embeddings."""
198
+
199
+ def __init__(
200
+ self,
201
+ dim: int,
202
+ num_heads: int = 8,
203
+ qkv_bias: bool = True,
204
+ use_rel_pos: bool = False,
205
+ rel_pos_zero_init: bool = True,
206
+ input_size: Optional[Tuple[int, int]] = None,
207
+ ) -> None:
208
+ """
209
+ Args:
210
+ dim (int): Number of input channels.
211
+ num_heads (int): Number of attention heads.
212
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
213
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
214
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
215
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
216
+ positional parameter size.
217
+ """
218
+ super().__init__()
219
+ self.num_heads = num_heads
220
+ head_dim = dim // num_heads
221
+ self.scale = head_dim**-0.5
222
+
223
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
224
+ self.proj = nn.Linear(dim, dim)
225
+
226
+ self.use_rel_pos = use_rel_pos
227
+ if self.use_rel_pos:
228
+ assert (
229
+ input_size is not None
230
+ ), "Input size must be provided if using relative positional encoding."
231
+ # initialize relative positional embeddings
232
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
233
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
234
+
235
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
236
+ B, H, W, _ = x.shape
237
+ # qkv with shape (3, B, nHead, H * W, C)
238
+ qkv = (
239
+ self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
240
+ )
241
+ # q, k, v with shape (B * nHead, H * W, C)
242
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
243
+
244
+ attn = (q * self.scale) @ k.transpose(-2, -1)
245
+
246
+ if self.use_rel_pos:
247
+ attn = add_decomposed_rel_pos(
248
+ attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
249
+ )
250
+
251
+ attn = attn.softmax(dim=-1)
252
+ x = (
253
+ (attn @ v)
254
+ .view(B, self.num_heads, H, W, -1)
255
+ .permute(0, 2, 3, 1, 4)
256
+ .reshape(B, H, W, -1)
257
+ )
258
+ x = self.proj(x)
259
+
260
+ return x
261
+
262
+
263
+ def window_partition(
264
+ x: torch.Tensor, window_size: int
265
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
266
+ """
267
+ Partition into non-overlapping windows with padding if needed.
268
+ Args:
269
+ x (tensor): input tokens with [B, H, W, C].
270
+ window_size (int): window size.
271
+
272
+ Returns:
273
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
274
+ (Hp, Wp): padded height and width before partition
275
+ """
276
+ B, H, W, C = x.shape
277
+
278
+ pad_h = (window_size - H % window_size) % window_size
279
+ pad_w = (window_size - W % window_size) % window_size
280
+ if pad_h > 0 or pad_w > 0:
281
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
282
+ Hp, Wp = H + pad_h, W + pad_w
283
+
284
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
285
+ windows = (
286
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
287
+ )
288
+ return windows, (Hp, Wp)
289
+
290
+
291
+ def window_unpartition(
292
+ windows: torch.Tensor,
293
+ window_size: int,
294
+ pad_hw: Tuple[int, int],
295
+ hw: Tuple[int, int],
296
+ ) -> torch.Tensor:
297
+ """
298
+ Window unpartition into original sequences and removing padding.
299
+ Args:
300
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
301
+ window_size (int): window size.
302
+ pad_hw (Tuple): padded height and width (Hp, Wp).
303
+ hw (Tuple): original height and width (H, W) before padding.
304
+
305
+ Returns:
306
+ x: unpartitioned sequences with [B, H, W, C].
307
+ """
308
+ Hp, Wp = pad_hw
309
+ H, W = hw
310
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
311
+ x = windows.view(
312
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
313
+ )
314
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
315
+
316
+ if Hp > H or Wp > W:
317
+ x = x[:, :H, :W, :].contiguous()
318
+ return x
319
+
320
+
321
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
322
+ """
323
+ Get relative positional embeddings according to the relative positions of
324
+ query and key sizes.
325
+ Args:
326
+ q_size (int): size of query q.
327
+ k_size (int): size of key k.
328
+ rel_pos (Tensor): relative position embeddings (L, C).
329
+
330
+ Returns:
331
+ Extracted positional embeddings according to relative positions.
332
+ """
333
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
334
+ # Interpolate rel pos if needed.
335
+ if rel_pos.shape[0] != max_rel_dist:
336
+ # Interpolate rel pos.
337
+ rel_pos_resized = F.interpolate(
338
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
339
+ size=max_rel_dist,
340
+ mode="linear",
341
+ )
342
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
343
+ else:
344
+ rel_pos_resized = rel_pos
345
+
346
+ # Scale the coords with short length if shapes for q and k are different.
347
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
348
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
349
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
350
+
351
+ return rel_pos_resized[relative_coords.long()]
352
+
353
+
354
+ def add_decomposed_rel_pos(
355
+ attn: torch.Tensor,
356
+ q: torch.Tensor,
357
+ rel_pos_h: torch.Tensor,
358
+ rel_pos_w: torch.Tensor,
359
+ q_size: Tuple[int, int],
360
+ k_size: Tuple[int, int],
361
+ ) -> torch.Tensor:
362
+ """
363
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
364
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
365
+ Args:
366
+ attn (Tensor): attention map.
367
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
368
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
369
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
370
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
371
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
372
+
373
+ Returns:
374
+ attn (Tensor): attention map with added relative positional embeddings.
375
+ """
376
+ q_h, q_w = q_size
377
+ k_h, k_w = k_size
378
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
379
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
380
+
381
+ B, _, dim = q.shape
382
+ r_q = q.reshape(B, q_h, q_w, dim)
383
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
384
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
385
+
386
+ attn = (
387
+ attn.view(B, q_h, q_w, k_h, k_w)
388
+ + rel_h[:, :, :, :, None]
389
+ + rel_w[:, :, :, None, :]
390
+ ).view(B, q_h * q_w, k_h * k_w)
391
+
392
+ return attn
393
+
394
+
395
+ class PatchEmbed(nn.Module):
396
+ """
397
+ Image to Patch Embedding.
398
+ """
399
+
400
+ def __init__(
401
+ self,
402
+ kernel_size: Tuple[int, int] = (16, 16),
403
+ stride: Tuple[int, int] = (16, 16),
404
+ padding: Tuple[int, int] = (0, 0),
405
+ in_chans: int = 3,
406
+ embed_dim: int = 768,
407
+ ) -> None:
408
+ """
409
+ Args:
410
+ kernel_size (Tuple): kernel size of the projection layer.
411
+ stride (Tuple): stride of the projection layer.
412
+ padding (Tuple): padding size of the projection layer.
413
+ in_chans (int): Number of input image channels.
414
+ embed_dim (int): Patch embedding dimension.
415
+ """
416
+ super().__init__()
417
+
418
+ self.proj = nn.Conv2d(
419
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
420
+ )
421
+
422
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
423
+ x = self.proj(x)
424
+ # B C H W -> B H W C
425
+ x = x.permute(0, 2, 3, 1)
426
+ return x
lisa_on_cuda/segment_anything/modeling/mask_decoder.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class MaskDecoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ *,
20
+ transformer_dim: int,
21
+ transformer: nn.Module,
22
+ num_multimask_outputs: int = 3,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ iou_head_depth: int = 3,
25
+ iou_head_hidden_dim: int = 256,
26
+ ) -> None:
27
+ """
28
+ Predicts masks given an image and prompt embeddings, using a
29
+ transformer architecture.
30
+
31
+ Arguments:
32
+ transformer_dim (int): the channel dimension of the transformer
33
+ transformer (nn.Module): the transformer used to predict masks
34
+ num_multimask_outputs (int): the number of masks to predict
35
+ when disambiguating masks
36
+ activation (nn.Module): the type of activation to use when
37
+ upscaling masks
38
+ iou_head_depth (int): the depth of the MLP used to predict
39
+ mask quality
40
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
41
+ used to predict mask quality
42
+ """
43
+ super().__init__()
44
+ self.transformer_dim = transformer_dim
45
+ self.transformer = transformer
46
+
47
+ self.num_multimask_outputs = num_multimask_outputs
48
+
49
+ self.iou_token = nn.Embedding(1, transformer_dim)
50
+ self.num_mask_tokens = num_multimask_outputs + 1
51
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52
+
53
+ self.output_upscaling = nn.Sequential(
54
+ nn.ConvTranspose2d(
55
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
56
+ ),
57
+ LayerNorm2d(transformer_dim // 4),
58
+ activation(),
59
+ nn.ConvTranspose2d(
60
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
61
+ ),
62
+ activation(),
63
+ )
64
+ self.output_hypernetworks_mlps = nn.ModuleList(
65
+ [
66
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
67
+ for i in range(self.num_mask_tokens)
68
+ ]
69
+ )
70
+
71
+ self.iou_prediction_head = MLP(
72
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
73
+ )
74
+
75
+ def forward(
76
+ self,
77
+ image_embeddings: torch.Tensor,
78
+ image_pe: torch.Tensor,
79
+ sparse_prompt_embeddings: torch.Tensor,
80
+ dense_prompt_embeddings: torch.Tensor,
81
+ multimask_output: bool,
82
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
83
+ """
84
+ Predict masks given image and prompt embeddings.
85
+
86
+ Arguments:
87
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
88
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
89
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
90
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
91
+ multimask_output (bool): Whether to return multiple masks or a single
92
+ mask.
93
+
94
+ Returns:
95
+ torch.Tensor: batched predicted masks
96
+ torch.Tensor: batched predictions of mask quality
97
+ """
98
+ masks, iou_pred = self.predict_masks(
99
+ image_embeddings=image_embeddings,
100
+ image_pe=image_pe,
101
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
102
+ dense_prompt_embeddings=dense_prompt_embeddings,
103
+ )
104
+
105
+ # Select the correct mask or masks for output
106
+ if multimask_output:
107
+ mask_slice = slice(1, None)
108
+ else:
109
+ mask_slice = slice(0, 1)
110
+ masks = masks[:, mask_slice, :, :]
111
+ iou_pred = iou_pred[:, mask_slice]
112
+
113
+ # Prepare output
114
+ return masks, iou_pred
115
+
116
+ def predict_masks(
117
+ self,
118
+ image_embeddings: torch.Tensor,
119
+ image_pe: torch.Tensor,
120
+ sparse_prompt_embeddings: torch.Tensor,
121
+ dense_prompt_embeddings: torch.Tensor,
122
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
123
+ """Predicts masks. See 'forward' for more details."""
124
+ # Concatenate output tokens
125
+ output_tokens = torch.cat(
126
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
127
+ )
128
+ output_tokens = output_tokens.unsqueeze(0).expand(
129
+ sparse_prompt_embeddings.size(0), -1, -1
130
+ )
131
+
132
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
133
+
134
+ # image_embeddings: [1, C, H, W], tokens: [B, N, C]
135
+ # dense_prompt_embeddings: [B, C, H, W]
136
+ # Expand per-image data in batch direction to be per-mask
137
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
138
+ src = src + dense_prompt_embeddings
139
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
140
+ b, c, h, w = src.shape
141
+
142
+ # Run the transformer
143
+ hs, src = self.transformer(src, pos_src, tokens)
144
+ iou_token_out = hs[:, 0, :]
145
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
146
+
147
+ # Upscale mask embeddings and predict masks using the mask tokens
148
+ src = src.transpose(1, 2).view(b, c, h, w)
149
+ upscaled_embedding = self.output_upscaling(src)
150
+ hyper_in_list: List[torch.Tensor] = []
151
+ for i in range(self.num_mask_tokens):
152
+ hyper_in_list.append(
153
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
154
+ )
155
+ hyper_in = torch.stack(hyper_in_list, dim=1)
156
+ b, c, h, w = upscaled_embedding.shape
157
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
158
+ b, self.num_mask_tokens, h, w
159
+ )
160
+
161
+ # Generate mask quality predictions
162
+ iou_pred = self.iou_prediction_head(iou_token_out)
163
+
164
+ return masks, iou_pred
165
+
166
+
167
+ # Lightly adapted from
168
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
169
+ class MLP(nn.Module):
170
+ def __init__(
171
+ self,
172
+ input_dim: int,
173
+ hidden_dim: int,
174
+ output_dim: int,
175
+ num_layers: int,
176
+ sigmoid_output: bool = False,
177
+ ) -> None:
178
+ super().__init__()
179
+ self.num_layers = num_layers
180
+ h = [hidden_dim] * (num_layers - 1)
181
+ self.layers = nn.ModuleList(
182
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
183
+ )
184
+ self.sigmoid_output = sigmoid_output
185
+
186
+ def forward(self, x):
187
+ for i, layer in enumerate(self.layers):
188
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
189
+ if self.sigmoid_output:
190
+ x = F.sigmoid(x)
191
+ return x
lisa_on_cuda/segment_anything/modeling/prompt_encoder.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, Optional, Tuple, Type
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch import nn
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class PromptEncoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ embed_dim: int,
20
+ image_embedding_size: Tuple[int, int],
21
+ input_image_size: Tuple[int, int],
22
+ mask_in_chans: int,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ ) -> None:
25
+ """
26
+ Encodes prompts for input to SAM's mask decoder.
27
+
28
+ Arguments:
29
+ embed_dim (int): The prompts' embedding dimension
30
+ image_embedding_size (tuple(int, int)): The spatial size of the
31
+ image embedding, as (H, W).
32
+ input_image_size (int): The padded size of the image as input
33
+ to the image encoder, as (H, W).
34
+ mask_in_chans (int): The number of hidden channels used for
35
+ encoding input masks.
36
+ activation (nn.Module): The activation to use when encoding
37
+ input masks.
38
+ """
39
+ super().__init__()
40
+ self.embed_dim = embed_dim
41
+ self.input_image_size = input_image_size
42
+ self.image_embedding_size = image_embedding_size
43
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44
+
45
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46
+ point_embeddings = [
47
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
48
+ ]
49
+ self.point_embeddings = nn.ModuleList(point_embeddings)
50
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
51
+
52
+ self.mask_input_size = (
53
+ 4 * image_embedding_size[0],
54
+ 4 * image_embedding_size[1],
55
+ )
56
+ self.mask_downscaling = nn.Sequential(
57
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
58
+ LayerNorm2d(mask_in_chans // 4),
59
+ activation(),
60
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
61
+ LayerNorm2d(mask_in_chans),
62
+ activation(),
63
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
64
+ )
65
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
66
+
67
+ def get_dense_pe(self) -> torch.Tensor:
68
+ """
69
+ Returns the positional encoding used to encode point prompts,
70
+ applied to a dense set of points the shape of the image encoding.
71
+
72
+ Returns:
73
+ torch.Tensor: Positional encoding with shape
74
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
75
+ """
76
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
77
+
78
+ def _embed_points(
79
+ self,
80
+ points: torch.Tensor,
81
+ labels: torch.Tensor,
82
+ pad: bool,
83
+ ) -> torch.Tensor:
84
+ """Embeds point prompts."""
85
+ points = points + 0.5 # Shift to center of pixel
86
+ if pad:
87
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
88
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
89
+ points = torch.cat([points, padding_point], dim=1)
90
+ labels = torch.cat([labels, padding_label], dim=1)
91
+ point_embedding = self.pe_layer.forward_with_coords(
92
+ points, self.input_image_size
93
+ )
94
+ point_embedding[labels == -1] = 0.0
95
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
96
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
97
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
98
+ return point_embedding
99
+
100
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
101
+ """Embeds box prompts."""
102
+ boxes = boxes + 0.5 # Shift to center of pixel
103
+ coords = boxes.reshape(-1, 2, 2)
104
+ corner_embedding = self.pe_layer.forward_with_coords(
105
+ coords, self.input_image_size
106
+ )
107
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
108
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
109
+ return corner_embedding
110
+
111
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
112
+ """Embeds mask inputs."""
113
+ mask_embedding = self.mask_downscaling(masks)
114
+ return mask_embedding
115
+
116
+ def _get_batch_size(
117
+ self,
118
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
119
+ boxes: Optional[torch.Tensor],
120
+ masks: Optional[torch.Tensor],
121
+ text_embeds: Optional[torch.Tensor],
122
+ ) -> int:
123
+ """
124
+ Gets the batch size of the output given the batch size of the input prompts.
125
+ """
126
+ if points is not None:
127
+ return points[0].shape[0]
128
+ elif boxes is not None:
129
+ return boxes.shape[0]
130
+ elif masks is not None:
131
+ return masks.shape[0]
132
+ elif text_embeds is not None:
133
+ return text_embeds.shape[0]
134
+ else:
135
+ return 1
136
+
137
+ def _get_device(self) -> torch.device:
138
+ return self.point_embeddings[0].weight.device
139
+
140
+ def forward(
141
+ self,
142
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
143
+ boxes: Optional[torch.Tensor],
144
+ masks: Optional[torch.Tensor],
145
+ text_embeds: Optional[torch.Tensor],
146
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
147
+ """
148
+ Embeds different types of prompts, returning both sparse and dense
149
+ embeddings.
150
+
151
+ Arguments:
152
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
153
+ and labels to embed.
154
+ boxes (torch.Tensor or none): boxes to embed
155
+ masks (torch.Tensor or none): masks to embed
156
+
157
+ Returns:
158
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
159
+ BxNx(embed_dim), where N is determined by the number of input points
160
+ and boxes.
161
+ torch.Tensor: dense embeddings for the masks, in the shape
162
+ Bx(embed_dim)x(embed_H)x(embed_W)
163
+ """
164
+ bs = self._get_batch_size(points, boxes, masks, text_embeds)
165
+ sparse_embeddings = torch.empty(
166
+ (bs, 0, self.embed_dim), device=self._get_device()
167
+ )
168
+ if points is not None:
169
+ coords, labels = points
170
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
171
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
172
+ if boxes is not None:
173
+ box_embeddings = self._embed_boxes(boxes)
174
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
175
+
176
+ if text_embeds is not None:
177
+ sparse_embeddings = torch.cat([sparse_embeddings, text_embeds], dim=1)
178
+
179
+ if masks is not None:
180
+ dense_embeddings = self._embed_masks(masks)
181
+ else:
182
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
183
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
184
+ )
185
+
186
+ return sparse_embeddings, dense_embeddings
187
+
188
+
189
+ class PositionEmbeddingRandom(nn.Module):
190
+ """
191
+ Positional encoding using random spatial frequencies.
192
+ """
193
+
194
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
195
+ super().__init__()
196
+ if scale is None or scale <= 0.0:
197
+ scale = 1.0
198
+ self.register_buffer(
199
+ "positional_encoding_gaussian_matrix",
200
+ scale * torch.randn((2, num_pos_feats)),
201
+ )
202
+
203
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
204
+ """Positionally encode points that are normalized to [0,1]."""
205
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
206
+ coords = 2 * coords - 1
207
+
208
+ if coords.dtype != self.positional_encoding_gaussian_matrix.dtype:
209
+ coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
210
+
211
+ coords = coords @ self.positional_encoding_gaussian_matrix
212
+ coords = 2 * np.pi * coords
213
+ # outputs d_1 x ... x d_n x C shape
214
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
215
+
216
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
217
+ """Generate positional encoding for a grid of the specified size."""
218
+ h, w = size
219
+ device: Any = self.positional_encoding_gaussian_matrix.device
220
+ grid = torch.ones(
221
+ (h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype
222
+ )
223
+ y_embed = grid.cumsum(dim=0) - 0.5
224
+ x_embed = grid.cumsum(dim=1) - 0.5
225
+ y_embed = y_embed / h
226
+ x_embed = x_embed / w
227
+
228
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
229
+ return pe.permute(2, 0, 1) # C x H x W
230
+
231
+ def forward_with_coords(
232
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
233
+ ) -> torch.Tensor:
234
+ """Positionally encode points that are not normalized to [0,1]."""
235
+ coords = coords_input.clone()
236
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
237
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
238
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
lisa_on_cuda/segment_anything/modeling/sam.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, Dict, List, Tuple
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+
13
+ from .image_encoder import ImageEncoderViT
14
+ from .mask_decoder import MaskDecoder
15
+ from .prompt_encoder import PromptEncoder
16
+
17
+
18
+ class Sam(nn.Module):
19
+ mask_threshold: float = 0.0
20
+ image_format: str = "RGB"
21
+
22
+ def __init__(
23
+ self,
24
+ image_encoder: ImageEncoderViT,
25
+ prompt_encoder: PromptEncoder,
26
+ mask_decoder: MaskDecoder,
27
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
28
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
29
+ ) -> None:
30
+ """
31
+ SAM predicts object masks from an image and input prompts.
32
+
33
+ Arguments:
34
+ image_encoder (ImageEncoderViT): The backbone used to encode the
35
+ image into image embeddings that allow for efficient mask prediction.
36
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38
+ and encoded prompts.
39
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
41
+ """
42
+ super().__init__()
43
+ self.image_encoder = image_encoder
44
+ self.prompt_encoder = prompt_encoder
45
+ self.mask_decoder = mask_decoder
46
+ self.register_buffer(
47
+ "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
48
+ )
49
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
50
+
51
+ @property
52
+ def device(self) -> Any:
53
+ return self.pixel_mean.device
54
+
55
+ @torch.no_grad()
56
+ def forward(
57
+ self,
58
+ batched_input: List[Dict[str, Any]],
59
+ multimask_output: bool,
60
+ ) -> List[Dict[str, torch.Tensor]]:
61
+ """
62
+ Predicts masks end-to-end from provided images and prompts.
63
+ If prompts are not known in advance, using SamPredictor is
64
+ recommended over calling the model directly.
65
+
66
+ Arguments:
67
+ batched_input (list(dict)): A list over input images, each a
68
+ dictionary with the following keys. A prompt key can be
69
+ excluded if it is not present.
70
+ 'image': The image as a torch tensor in 3xHxW format,
71
+ already transformed for input to the model.
72
+ 'original_size': (tuple(int, int)) The original size of
73
+ the image before transformation, as (H, W).
74
+ 'point_coords': (torch.Tensor) Batched point prompts for
75
+ this image, with shape BxNx2. Already transformed to the
76
+ input frame of the model.
77
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
78
+ with shape BxN.
79
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
80
+ Already transformed to the input frame of the model.
81
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
82
+ in the form Bx1xHxW.
83
+ multimask_output (bool): Whether the model should predict multiple
84
+ disambiguating masks, or return a single mask.
85
+
86
+ Returns:
87
+ (list(dict)): A list over input images, where each element is
88
+ as dictionary with the following keys.
89
+ 'masks': (torch.Tensor) Batched binary mask predictions,
90
+ with shape BxCxHxW, where B is the number of input prompts,
91
+ C is determined by multimask_output, and (H, W) is the
92
+ original size of the image.
93
+ 'iou_predictions': (torch.Tensor) The model's predictions
94
+ of mask quality, in shape BxC.
95
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
96
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
97
+ to subsequent iterations of prediction.
98
+ """
99
+ input_images = torch.stack(
100
+ [self.preprocess(x["image"]) for x in batched_input], dim=0
101
+ )
102
+ image_embeddings = self.image_encoder(input_images)
103
+
104
+ outputs = []
105
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
106
+ if "point_coords" in image_record:
107
+ points = (image_record["point_coords"], image_record["point_labels"])
108
+ else:
109
+ points = None
110
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
111
+ points=points,
112
+ boxes=image_record.get("boxes", None),
113
+ masks=image_record.get("mask_inputs", None),
114
+ )
115
+ low_res_masks, iou_predictions = self.mask_decoder(
116
+ image_embeddings=curr_embedding.unsqueeze(0),
117
+ image_pe=self.prompt_encoder.get_dense_pe(),
118
+ sparse_prompt_embeddings=sparse_embeddings,
119
+ dense_prompt_embeddings=dense_embeddings,
120
+ multimask_output=multimask_output,
121
+ )
122
+ masks = self.postprocess_masks(
123
+ low_res_masks,
124
+ input_size=image_record["image"].shape[-2:],
125
+ original_size=image_record["original_size"],
126
+ )
127
+ masks = masks > self.mask_threshold
128
+ outputs.append(
129
+ {
130
+ "masks": masks,
131
+ "iou_predictions": iou_predictions,
132
+ "low_res_logits": low_res_masks,
133
+ }
134
+ )
135
+ return outputs
136
+
137
+ def postprocess_masks(
138
+ self,
139
+ masks: torch.Tensor,
140
+ input_size: Tuple[int, ...],
141
+ original_size: Tuple[int, ...],
142
+ ) -> torch.Tensor:
143
+ """
144
+ Remove padding and upscale masks to the original image size.
145
+
146
+ Arguments:
147
+ masks (torch.Tensor): Batched masks from the mask_decoder,
148
+ in BxCxHxW format.
149
+ input_size (tuple(int, int)): The size of the image input to the
150
+ model, in (H, W) format. Used to remove padding.
151
+ original_size (tuple(int, int)): The original size of the image
152
+ before resizing for input to the model, in (H, W) format.
153
+
154
+ Returns:
155
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
156
+ is given by original_size.
157
+ """
158
+
159
+ dtype = masks.dtype
160
+
161
+ masks = F.interpolate(
162
+ masks.float(),
163
+ (self.image_encoder.img_size, self.image_encoder.img_size),
164
+ mode="bilinear",
165
+ align_corners=False,
166
+ )
167
+ # masks = masks.to(dtype)
168
+ masks = masks[..., : input_size[0], : input_size[1]]
169
+ masks = F.interpolate(
170
+ masks, original_size, mode="bilinear", align_corners=False
171
+ )
172
+ return masks
173
+
174
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
175
+ """Normalize pixel values and pad to a square input."""
176
+ # Normalize colors
177
+ x = (x - self.pixel_mean) / self.pixel_std
178
+
179
+ # Pad
180
+ h, w = x.shape[-2:]
181
+ padh = self.image_encoder.img_size - h
182
+ padw = self.image_encoder.img_size - w
183
+ x = F.pad(x, (0, padw, 0, padh))
184
+ return x
lisa_on_cuda/segment_anything/modeling/transformer.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple, Type
9
+
10
+ import torch
11
+ from torch import Tensor, nn
12
+
13
+ from .common import MLPBlock
14
+
15
+
16
+ class TwoWayTransformer(nn.Module):
17
+ def __init__(
18
+ self,
19
+ depth: int,
20
+ embedding_dim: int,
21
+ num_heads: int,
22
+ mlp_dim: int,
23
+ activation: Type[nn.Module] = nn.ReLU,
24
+ attention_downsample_rate: int = 2,
25
+ ) -> None:
26
+ """
27
+ A transformer decoder that attends to an input image using
28
+ queries whose positional embedding is supplied.
29
+
30
+ Args:
31
+ depth (int): number of layers in the transformer
32
+ embedding_dim (int): the channel dimension for the input embeddings
33
+ num_heads (int): the number of heads for multihead attention. Must
34
+ divide embedding_dim
35
+ mlp_dim (int): the channel dimension internal to the MLP block
36
+ activation (nn.Module): the activation to use in the MLP block
37
+ """
38
+ super().__init__()
39
+ self.depth = depth
40
+ self.embedding_dim = embedding_dim
41
+ self.num_heads = num_heads
42
+ self.mlp_dim = mlp_dim
43
+ self.layers = nn.ModuleList()
44
+
45
+ for i in range(depth):
46
+ self.layers.append(
47
+ TwoWayAttentionBlock(
48
+ embedding_dim=embedding_dim,
49
+ num_heads=num_heads,
50
+ mlp_dim=mlp_dim,
51
+ activation=activation,
52
+ attention_downsample_rate=attention_downsample_rate,
53
+ skip_first_layer_pe=(i == 0),
54
+ )
55
+ )
56
+
57
+ self.final_attn_token_to_image = Attention(
58
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59
+ )
60
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
61
+
62
+ def forward(
63
+ self,
64
+ image_embedding: Tensor,
65
+ image_pe: Tensor,
66
+ point_embedding: Tensor,
67
+ ) -> Tuple[Tensor, Tensor]:
68
+ """
69
+ Args:
70
+ image_embedding (torch.Tensor): image to attend to. Should be shape
71
+ B x embedding_dim x h x w for any h and w.
72
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
73
+ have the same shape as image_embedding.
74
+ point_embedding (torch.Tensor): the embedding to add to the query points.
75
+ Must have shape B x N_points x embedding_dim for any N_points.
76
+
77
+ Returns:
78
+ torch.Tensor: the processed point_embedding
79
+ torch.Tensor: the processed image_embedding
80
+ """
81
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82
+ bs, c, h, w = image_embedding.shape
83
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
85
+
86
+ # Prepare queries
87
+ queries = point_embedding
88
+ keys = image_embedding
89
+
90
+ # Apply transformer blocks and final layernorm
91
+ for layer in self.layers:
92
+ queries, keys = layer(
93
+ queries=queries,
94
+ keys=keys,
95
+ query_pe=point_embedding,
96
+ key_pe=image_pe,
97
+ )
98
+
99
+ # Apply the final attention layer from the points to the image
100
+ q = queries + point_embedding
101
+ k = keys + image_pe
102
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103
+ queries = queries + attn_out
104
+ queries = self.norm_final_attn(queries)
105
+
106
+ return queries, keys
107
+
108
+
109
+ class TwoWayAttentionBlock(nn.Module):
110
+ def __init__(
111
+ self,
112
+ embedding_dim: int,
113
+ num_heads: int,
114
+ mlp_dim: int = 2048,
115
+ activation: Type[nn.Module] = nn.ReLU,
116
+ attention_downsample_rate: int = 2,
117
+ skip_first_layer_pe: bool = False,
118
+ ) -> None:
119
+ """
120
+ A transformer block with four layers: (1) self-attention of sparse
121
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
123
+ inputs.
124
+
125
+ Arguments:
126
+ embedding_dim (int): the channel dimension of the embeddings
127
+ num_heads (int): the number of heads in the attention layers
128
+ mlp_dim (int): the hidden dimension of the mlp block
129
+ activation (nn.Module): the activation of the mlp block
130
+ skip_first_layer_pe (bool): skip the PE on the first layer
131
+ """
132
+ super().__init__()
133
+ self.self_attn = Attention(embedding_dim, num_heads)
134
+ self.norm1 = nn.LayerNorm(embedding_dim)
135
+
136
+ self.cross_attn_token_to_image = Attention(
137
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138
+ )
139
+ self.norm2 = nn.LayerNorm(embedding_dim)
140
+
141
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142
+ self.norm3 = nn.LayerNorm(embedding_dim)
143
+
144
+ self.norm4 = nn.LayerNorm(embedding_dim)
145
+ self.cross_attn_image_to_token = Attention(
146
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147
+ )
148
+
149
+ self.skip_first_layer_pe = skip_first_layer_pe
150
+
151
+ def forward(
152
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153
+ ) -> Tuple[Tensor, Tensor]:
154
+ # Self attention block
155
+ if self.skip_first_layer_pe:
156
+ queries = self.self_attn(q=queries, k=queries, v=queries)
157
+ else:
158
+ q = queries + query_pe
159
+ attn_out = self.self_attn(q=q, k=q, v=queries)
160
+ queries = queries + attn_out
161
+ queries = self.norm1(queries)
162
+
163
+ # Cross attention block, tokens attending to image embedding
164
+ q = queries + query_pe
165
+ k = keys + key_pe
166
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167
+ queries = queries + attn_out
168
+ queries = self.norm2(queries)
169
+
170
+ # MLP block
171
+ mlp_out = self.mlp(queries)
172
+ queries = queries + mlp_out
173
+ queries = self.norm3(queries)
174
+
175
+ # Cross attention block, image embedding attending to tokens
176
+ q = queries + query_pe
177
+ k = keys + key_pe
178
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179
+ keys = keys + attn_out
180
+ keys = self.norm4(keys)
181
+
182
+ return queries, keys
183
+
184
+
185
+ class Attention(nn.Module):
186
+ """
187
+ An attention layer that allows for downscaling the size of the embedding
188
+ after projection to queries, keys, and values.
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ embedding_dim: int,
194
+ num_heads: int,
195
+ downsample_rate: int = 1,
196
+ ) -> None:
197
+ super().__init__()
198
+ self.embedding_dim = embedding_dim
199
+ self.internal_dim = embedding_dim // downsample_rate
200
+ self.num_heads = num_heads
201
+ assert (
202
+ self.internal_dim % num_heads == 0
203
+ ), "num_heads must divide embedding_dim."
204
+
205
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
206
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
207
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
208
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
209
+
210
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
211
+ b, n, c = x.shape
212
+ x = x.reshape(b, n, num_heads, c // num_heads)
213
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
214
+
215
+ def _recombine_heads(self, x: Tensor) -> Tensor:
216
+ b, n_heads, n_tokens, c_per_head = x.shape
217
+ x = x.transpose(1, 2)
218
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
219
+
220
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
221
+ # Input projections
222
+ q = self.q_proj(q)
223
+ k = self.k_proj(k)
224
+ v = self.v_proj(v)
225
+
226
+ # Separate into heads
227
+ q = self._separate_heads(q, self.num_heads)
228
+ k = self._separate_heads(k, self.num_heads)
229
+ v = self._separate_heads(v, self.num_heads)
230
+
231
+ # Attention
232
+ _, _, _, c_per_head = q.shape
233
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
234
+ attn = attn / math.sqrt(c_per_head)
235
+ attn = torch.softmax(attn, dim=-1)
236
+
237
+ # Get output
238
+ out = attn @ v
239
+ out = self._recombine_heads(out)
240
+ out = self.out_proj(out)
241
+
242
+ return out
lisa_on_cuda/segment_anything/predictor.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from .modeling import Sam
13
+ from .utils.transforms import ResizeLongestSide
14
+
15
+
16
+ class SamPredictor:
17
+ def __init__(
18
+ self,
19
+ sam_model: Sam,
20
+ ) -> None:
21
+ """
22
+ Uses SAM to calculate the image embedding for an image, and then
23
+ allow repeated, efficient mask prediction given prompts.
24
+
25
+ Arguments:
26
+ sam_model (Sam): The model to use for mask prediction.
27
+ """
28
+ super().__init__()
29
+ self.model = sam_model
30
+ self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
31
+ self.reset_image()
32
+
33
+ def set_image(
34
+ self,
35
+ image: np.ndarray,
36
+ image_format: str = "RGB",
37
+ ) -> None:
38
+ """
39
+ Calculates the image embeddings for the provided image, allowing
40
+ masks to be predicted with the 'predict' method.
41
+
42
+ Arguments:
43
+ image (np.ndarray): The image for calculating masks. Expects an
44
+ image in HWC uint8 format, with pixel values in [0, 255].
45
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
46
+ """
47
+ assert image_format in [
48
+ "RGB",
49
+ "BGR",
50
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
51
+ if image_format != self.model.image_format:
52
+ image = image[..., ::-1]
53
+
54
+ # Transform the image to the form expected by the model
55
+ input_image = self.transform.apply_image(image)
56
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
57
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
58
+ None, :, :, :
59
+ ]
60
+
61
+ self.set_torch_image(input_image_torch, image.shape[:2])
62
+
63
+ @torch.no_grad()
64
+ def set_torch_image(
65
+ self,
66
+ transformed_image: torch.Tensor,
67
+ original_image_size: Tuple[int, ...],
68
+ ) -> None:
69
+ """
70
+ Calculates the image embeddings for the provided image, allowing
71
+ masks to be predicted with the 'predict' method. Expects the input
72
+ image to be already transformed to the format expected by the model.
73
+
74
+ Arguments:
75
+ transformed_image (torch.Tensor): The input image, with shape
76
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
77
+ original_image_size (tuple(int, int)): The size of the image
78
+ before transformation, in (H, W) format.
79
+ """
80
+ assert (
81
+ len(transformed_image.shape) == 4
82
+ and transformed_image.shape[1] == 3
83
+ and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
84
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
85
+ self.reset_image()
86
+
87
+ self.original_size = original_image_size
88
+ self.input_size = tuple(transformed_image.shape[-2:])
89
+ input_image = self.model.preprocess(transformed_image)
90
+ self.features = self.model.image_encoder(input_image)
91
+ self.is_image_set = True
92
+
93
+ def predict(
94
+ self,
95
+ point_coords: Optional[np.ndarray] = None,
96
+ point_labels: Optional[np.ndarray] = None,
97
+ box: Optional[np.ndarray] = None,
98
+ mask_input: Optional[np.ndarray] = None,
99
+ multimask_output: bool = True,
100
+ return_logits: bool = False,
101
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
102
+ """
103
+ Predict masks for the given input prompts, using the currently set image.
104
+
105
+ Arguments:
106
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
107
+ model. Each point is in (X,Y) in pixels.
108
+ point_labels (np.ndarray or None): A length N array of labels for the
109
+ point prompts. 1 indicates a foreground point and 0 indicates a
110
+ background point.
111
+ box (np.ndarray or None): A length 4 array given a box prompt to the
112
+ model, in XYXY format.
113
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
114
+ coming from a previous prediction iteration. Has form 1xHxW, where
115
+ for SAM, H=W=256.
116
+ multimask_output (bool): If true, the model will return three masks.
117
+ For ambiguous input prompts (such as a single click), this will often
118
+ produce better masks than a single prediction. If only a single
119
+ mask is needed, the model's predicted quality score can be used
120
+ to select the best mask. For non-ambiguous prompts, such as multiple
121
+ input prompts, multimask_output=False can give better results.
122
+ return_logits (bool): If true, returns un-thresholded masks logits
123
+ instead of a binary mask.
124
+
125
+ Returns:
126
+ (np.ndarray): The output masks in CxHxW format, where C is the
127
+ number of masks, and (H, W) is the original image size.
128
+ (np.ndarray): An array of length C containing the model's
129
+ predictions for the quality of each mask.
130
+ (np.ndarray): An array of shape CxHxW, where C is the number
131
+ of masks and H=W=256. These low resolution logits can be passed to
132
+ a subsequent iteration as mask input.
133
+ """
134
+ if not self.is_image_set:
135
+ raise RuntimeError(
136
+ "An image must be set with .set_image(...) before mask prediction."
137
+ )
138
+
139
+ # Transform input prompts
140
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
141
+ if point_coords is not None:
142
+ assert (
143
+ point_labels is not None
144
+ ), "point_labels must be supplied if point_coords is supplied."
145
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
146
+ coords_torch = torch.as_tensor(
147
+ point_coords, dtype=torch.float, device=self.device
148
+ )
149
+ labels_torch = torch.as_tensor(
150
+ point_labels, dtype=torch.int, device=self.device
151
+ )
152
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
153
+ if box is not None:
154
+ box = self.transform.apply_boxes(box, self.original_size)
155
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
156
+ box_torch = box_torch[None, :]
157
+ if mask_input is not None:
158
+ mask_input_torch = torch.as_tensor(
159
+ mask_input, dtype=torch.float, device=self.device
160
+ )
161
+ mask_input_torch = mask_input_torch[None, :, :, :]
162
+
163
+ masks, iou_predictions, low_res_masks = self.predict_torch(
164
+ coords_torch,
165
+ labels_torch,
166
+ box_torch,
167
+ mask_input_torch,
168
+ multimask_output,
169
+ return_logits=return_logits,
170
+ )
171
+
172
+ masks_np = masks[0].detach().cpu().numpy()
173
+ iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
174
+ low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
175
+ return masks_np, iou_predictions_np, low_res_masks_np
176
+
177
+ @torch.no_grad()
178
+ def predict_torch(
179
+ self,
180
+ point_coords: Optional[torch.Tensor],
181
+ point_labels: Optional[torch.Tensor],
182
+ boxes: Optional[torch.Tensor] = None,
183
+ mask_input: Optional[torch.Tensor] = None,
184
+ multimask_output: bool = True,
185
+ return_logits: bool = False,
186
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
187
+ """
188
+ Predict masks for the given input prompts, using the currently set image.
189
+ Input prompts are batched torch tensors and are expected to already be
190
+ transformed to the input frame using ResizeLongestSide.
191
+
192
+ Arguments:
193
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
194
+ model. Each point is in (X,Y) in pixels.
195
+ point_labels (torch.Tensor or None): A BxN array of labels for the
196
+ point prompts. 1 indicates a foreground point and 0 indicates a
197
+ background point.
198
+ boxes (np.ndarray or None): A Bx4 array given a box prompt to the
199
+ model, in XYXY format.
200
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
201
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
202
+ for SAM, H=W=256. Masks returned by a previous iteration of the
203
+ predict method do not need further transformation.
204
+ multimask_output (bool): If true, the model will return three masks.
205
+ For ambiguous input prompts (such as a single click), this will often
206
+ produce better masks than a single prediction. If only a single
207
+ mask is needed, the model's predicted quality score can be used
208
+ to select the best mask. For non-ambiguous prompts, such as multiple
209
+ input prompts, multimask_output=False can give better results.
210
+ return_logits (bool): If true, returns un-thresholded masks logits
211
+ instead of a binary mask.
212
+
213
+ Returns:
214
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
215
+ number of masks, and (H, W) is the original image size.
216
+ (torch.Tensor): An array of shape BxC containing the model's
217
+ predictions for the quality of each mask.
218
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
219
+ of masks and H=W=256. These low res logits can be passed to
220
+ a subsequent iteration as mask input.
221
+ """
222
+ if not self.is_image_set:
223
+ raise RuntimeError(
224
+ "An image must be set with .set_image(...) before mask prediction."
225
+ )
226
+
227
+ if point_coords is not None:
228
+ points = (point_coords, point_labels)
229
+ else:
230
+ points = None
231
+
232
+ # Embed prompts
233
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
234
+ points=points,
235
+ boxes=boxes,
236
+ masks=mask_input,
237
+ )
238
+
239
+ # Predict masks
240
+ low_res_masks, iou_predictions = self.model.mask_decoder(
241
+ image_embeddings=self.features,
242
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
243
+ sparse_prompt_embeddings=sparse_embeddings,
244
+ dense_prompt_embeddings=dense_embeddings,
245
+ multimask_output=multimask_output,
246
+ )
247
+
248
+ # Upscale the masks to the original image resolution
249
+ masks = self.model.postprocess_masks(
250
+ low_res_masks, self.input_size, self.original_size
251
+ )
252
+
253
+ if not return_logits:
254
+ masks = masks > self.model.mask_threshold
255
+
256
+ return masks, iou_predictions, low_res_masks
257
+
258
+ def get_image_embedding(self) -> torch.Tensor:
259
+ """
260
+ Returns the image embeddings for the currently set image, with
261
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
262
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
263
+ """
264
+ if not self.is_image_set:
265
+ raise RuntimeError(
266
+ "An image must be set with .set_image(...) to generate an embedding."
267
+ )
268
+ assert (
269
+ self.features is not None
270
+ ), "Features must exist if an image has been set."
271
+ return self.features
272
+
273
+ @property
274
+ def device(self) -> torch.device:
275
+ return self.model.device
276
+
277
+ def reset_image(self) -> None:
278
+ """Resets the currently set image."""
279
+ self.is_image_set = False
280
+ self.features = None
281
+ self.orig_h = None
282
+ self.orig_w = None
283
+ self.input_h = None
284
+ self.input_w = None
lisa_on_cuda/segment_anything/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
lisa_on_cuda/segment_anything/utils/amg.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from copy import deepcopy
9
+ from itertools import product
10
+ from typing import Any, Dict, Generator, ItemsView, List, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+
16
+ class MaskData:
17
+ """
18
+ A structure for storing masks and their related data in batched format.
19
+ Implements basic filtering and concatenation.
20
+ """
21
+
22
+ def __init__(self, **kwargs) -> None:
23
+ for v in kwargs.values():
24
+ assert isinstance(
25
+ v, (list, np.ndarray, torch.Tensor)
26
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
27
+ self._stats = dict(**kwargs)
28
+
29
+ def __setitem__(self, key: str, item: Any) -> None:
30
+ assert isinstance(
31
+ item, (list, np.ndarray, torch.Tensor)
32
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
33
+ self._stats[key] = item
34
+
35
+ def __delitem__(self, key: str) -> None:
36
+ del self._stats[key]
37
+
38
+ def __getitem__(self, key: str) -> Any:
39
+ return self._stats[key]
40
+
41
+ def items(self) -> ItemsView[str, Any]:
42
+ return self._stats.items()
43
+
44
+ def filter(self, keep: torch.Tensor) -> None:
45
+ for k, v in self._stats.items():
46
+ if v is None:
47
+ self._stats[k] = None
48
+ elif isinstance(v, torch.Tensor):
49
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
50
+ elif isinstance(v, np.ndarray):
51
+ self._stats[k] = v[keep.detach().cpu().numpy()]
52
+ elif isinstance(v, list) and keep.dtype == torch.bool:
53
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
54
+ elif isinstance(v, list):
55
+ self._stats[k] = [v[i] for i in keep]
56
+ else:
57
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
58
+
59
+ def cat(self, new_stats: "MaskData") -> None:
60
+ for k, v in new_stats.items():
61
+ if k not in self._stats or self._stats[k] is None:
62
+ self._stats[k] = deepcopy(v)
63
+ elif isinstance(v, torch.Tensor):
64
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
65
+ elif isinstance(v, np.ndarray):
66
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
67
+ elif isinstance(v, list):
68
+ self._stats[k] = self._stats[k] + deepcopy(v)
69
+ else:
70
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
71
+
72
+ def to_numpy(self) -> None:
73
+ for k, v in self._stats.items():
74
+ if isinstance(v, torch.Tensor):
75
+ self._stats[k] = v.detach().cpu().numpy()
76
+
77
+
78
+ def is_box_near_crop_edge(
79
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
80
+ ) -> torch.Tensor:
81
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
82
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
83
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
84
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
85
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
86
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
87
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
88
+ return torch.any(near_crop_edge, dim=1)
89
+
90
+
91
+ def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
92
+ box_xywh = deepcopy(box_xyxy)
93
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
94
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
95
+ return box_xywh
96
+
97
+
98
+ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
99
+ assert len(args) > 0 and all(
100
+ len(a) == len(args[0]) for a in args
101
+ ), "Batched iteration must have inputs of all the same size."
102
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
103
+ for b in range(n_batches):
104
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
105
+
106
+
107
+ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
108
+ """
109
+ Encodes masks to an uncompressed RLE, in the format expected by
110
+ pycoco tools.
111
+ """
112
+ # Put in fortran order and flatten h,w
113
+ b, h, w = tensor.shape
114
+ tensor = tensor.permute(0, 2, 1).flatten(1)
115
+
116
+ # Compute change indices
117
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
118
+ change_indices = diff.nonzero()
119
+
120
+ # Encode run length
121
+ out = []
122
+ for i in range(b):
123
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
124
+ cur_idxs = torch.cat(
125
+ [
126
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
127
+ cur_idxs + 1,
128
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
129
+ ]
130
+ )
131
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
132
+ counts = [] if tensor[i, 0] == 0 else [0]
133
+ counts.extend(btw_idxs.detach().cpu().tolist())
134
+ out.append({"size": [h, w], "counts": counts})
135
+ return out
136
+
137
+
138
+ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
139
+ """Compute a binary mask from an uncompressed RLE."""
140
+ h, w = rle["size"]
141
+ mask = np.empty(h * w, dtype=bool)
142
+ idx = 0
143
+ parity = False
144
+ for count in rle["counts"]:
145
+ mask[idx : idx + count] = parity
146
+ idx += count
147
+ parity ^= True
148
+ mask = mask.reshape(w, h)
149
+ return mask.transpose() # Put in C order
150
+
151
+
152
+ def area_from_rle(rle: Dict[str, Any]) -> int:
153
+ return sum(rle["counts"][1::2])
154
+
155
+
156
+ def calculate_stability_score(
157
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
158
+ ) -> torch.Tensor:
159
+ """
160
+ Computes the stability score for a batch of masks. The stability
161
+ score is the IoU between the binary masks obtained by thresholding
162
+ the predicted mask logits at high and low values.
163
+ """
164
+ # One mask is always contained inside the other.
165
+ # Save memory by preventing unnecessary cast to torch.int64
166
+ intersections = (
167
+ (masks > (mask_threshold + threshold_offset))
168
+ .sum(-1, dtype=torch.int16)
169
+ .sum(-1, dtype=torch.int32)
170
+ )
171
+ unions = (
172
+ (masks > (mask_threshold - threshold_offset))
173
+ .sum(-1, dtype=torch.int16)
174
+ .sum(-1, dtype=torch.int32)
175
+ )
176
+ return intersections / unions
177
+
178
+
179
+ def build_point_grid(n_per_side: int) -> np.ndarray:
180
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
181
+ offset = 1 / (2 * n_per_side)
182
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
183
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
184
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
185
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
186
+ return points
187
+
188
+
189
+ def build_all_layer_point_grids(
190
+ n_per_side: int, n_layers: int, scale_per_layer: int
191
+ ) -> List[np.ndarray]:
192
+ """Generates point grids for all crop layers."""
193
+ points_by_layer = []
194
+ for i in range(n_layers + 1):
195
+ n_points = int(n_per_side / (scale_per_layer**i))
196
+ points_by_layer.append(build_point_grid(n_points))
197
+ return points_by_layer
198
+
199
+
200
+ def generate_crop_boxes(
201
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
202
+ ) -> Tuple[List[List[int]], List[int]]:
203
+ """
204
+ Generates a list of crop boxes of different sizes. Each layer
205
+ has (2**i)**2 boxes for the ith layer.
206
+ """
207
+ crop_boxes, layer_idxs = [], []
208
+ im_h, im_w = im_size
209
+ short_side = min(im_h, im_w)
210
+
211
+ # Original image
212
+ crop_boxes.append([0, 0, im_w, im_h])
213
+ layer_idxs.append(0)
214
+
215
+ def crop_len(orig_len, n_crops, overlap):
216
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
217
+
218
+ for i_layer in range(n_layers):
219
+ n_crops_per_side = 2 ** (i_layer + 1)
220
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
221
+
222
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
223
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
224
+
225
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
226
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
227
+
228
+ # Crops in XYWH format
229
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
230
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
231
+ crop_boxes.append(box)
232
+ layer_idxs.append(i_layer + 1)
233
+
234
+ return crop_boxes, layer_idxs
235
+
236
+
237
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
238
+ x0, y0, _, _ = crop_box
239
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
240
+ # Check if boxes has a channel dimension
241
+ if len(boxes.shape) == 3:
242
+ offset = offset.unsqueeze(1)
243
+ return boxes + offset
244
+
245
+
246
+ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
247
+ x0, y0, _, _ = crop_box
248
+ offset = torch.tensor([[x0, y0]], device=points.device)
249
+ # Check if points has a channel dimension
250
+ if len(points.shape) == 3:
251
+ offset = offset.unsqueeze(1)
252
+ return points + offset
253
+
254
+
255
+ def uncrop_masks(
256
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
257
+ ) -> torch.Tensor:
258
+ x0, y0, x1, y1 = crop_box
259
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
260
+ return masks
261
+ # Coordinate transform masks
262
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
263
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
264
+ return torch.nn.functional.pad(masks, pad, value=0)
265
+
266
+
267
+ def remove_small_regions(
268
+ mask: np.ndarray, area_thresh: float, mode: str
269
+ ) -> Tuple[np.ndarray, bool]:
270
+ """
271
+ Removes small disconnected regions and holes in a mask. Returns the
272
+ mask and an indicator of if the mask has been modified.
273
+ """
274
+ import cv2 # type: ignore
275
+
276
+ assert mode in ["holes", "islands"]
277
+ correct_holes = mode == "holes"
278
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
279
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
280
+ sizes = stats[:, -1][1:] # Row 0 is background label
281
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
282
+ if len(small_regions) == 0:
283
+ return mask, False
284
+ fill_labels = [0] + small_regions
285
+ if not correct_holes:
286
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
287
+ # If every region is below threshold, keep largest
288
+ if len(fill_labels) == 0:
289
+ fill_labels = [int(np.argmax(sizes)) + 1]
290
+ mask = np.isin(regions, fill_labels)
291
+ return mask, True
292
+
293
+
294
+ def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
295
+ from pycocotools import mask as mask_utils # type: ignore
296
+
297
+ h, w = uncompressed_rle["size"]
298
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
299
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
300
+ return rle
301
+
302
+
303
+ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
304
+ """
305
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
306
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
307
+ """
308
+ # torch.max below raises an error on empty inputs, just skip in this case
309
+ if torch.numel(masks) == 0:
310
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
311
+
312
+ # Normalize shape to CxHxW
313
+ shape = masks.shape
314
+ h, w = shape[-2:]
315
+ if len(shape) > 2:
316
+ masks = masks.flatten(0, -3)
317
+ else:
318
+ masks = masks.unsqueeze(0)
319
+
320
+ # Get top and bottom edges
321
+ in_height, _ = torch.max(masks, dim=-1)
322
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
323
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
324
+ in_height_coords = in_height_coords + h * (~in_height)
325
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
326
+
327
+ # Get left and right edges
328
+ in_width, _ = torch.max(masks, dim=-2)
329
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
330
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
331
+ in_width_coords = in_width_coords + w * (~in_width)
332
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
333
+
334
+ # If the mask is empty the right edge will be to the left of the left edge.
335
+ # Replace these boxes with [0, 0, 0, 0]
336
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
337
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
338
+ out = out * (~empty_filter).unsqueeze(-1)
339
+
340
+ # Return to original shape
341
+ if len(shape) > 2:
342
+ out = out.reshape(*shape[:-2], 4)
343
+ else:
344
+ out = out[0]
345
+
346
+ return out
lisa_on_cuda/segment_anything/utils/onnx.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+
13
+ from ..modeling import Sam
14
+ from .amg import calculate_stability_score
15
+
16
+
17
+ class SamOnnxModel(nn.Module):
18
+ """
19
+ This model should not be called directly, but is used in ONNX export.
20
+ It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21
+ with some functions modified to enable model tracing. Also supports extra
22
+ options controlling what information. See the ONNX export script for details.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ model: Sam,
28
+ return_single_mask: bool,
29
+ use_stability_score: bool = False,
30
+ return_extra_metrics: bool = False,
31
+ ) -> None:
32
+ super().__init__()
33
+ self.mask_decoder = model.mask_decoder
34
+ self.model = model
35
+ self.img_size = model.image_encoder.img_size
36
+ self.return_single_mask = return_single_mask
37
+ self.use_stability_score = use_stability_score
38
+ self.stability_score_offset = 1.0
39
+ self.return_extra_metrics = return_extra_metrics
40
+
41
+ @staticmethod
42
+ def resize_longest_image_size(
43
+ input_image_size: torch.Tensor, longest_side: int
44
+ ) -> torch.Tensor:
45
+ input_image_size = input_image_size.to(torch.float32)
46
+ scale = longest_side / torch.max(input_image_size)
47
+ transformed_size = scale * input_image_size
48
+ transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49
+ return transformed_size
50
+
51
+ def _embed_points(
52
+ self, point_coords: torch.Tensor, point_labels: torch.Tensor
53
+ ) -> torch.Tensor:
54
+ point_coords = point_coords + 0.5
55
+ point_coords = point_coords / self.img_size
56
+ point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
57
+ point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
58
+
59
+ point_embedding = point_embedding * (point_labels != -1)
60
+ point_embedding = (
61
+ point_embedding
62
+ + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
63
+ )
64
+
65
+ for i in range(self.model.prompt_encoder.num_point_embeddings):
66
+ point_embedding = (
67
+ point_embedding
68
+ + self.model.prompt_encoder.point_embeddings[i].weight
69
+ * (point_labels == i)
70
+ )
71
+
72
+ return point_embedding
73
+
74
+ def _embed_masks(
75
+ self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
76
+ ) -> torch.Tensor:
77
+ mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
78
+ input_mask
79
+ )
80
+ mask_embedding = mask_embedding + (
81
+ 1 - has_mask_input
82
+ ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
83
+ return mask_embedding
84
+
85
+ def mask_postprocessing(
86
+ self, masks: torch.Tensor, orig_im_size: torch.Tensor
87
+ ) -> torch.Tensor:
88
+ masks = F.interpolate(
89
+ masks,
90
+ size=(self.img_size, self.img_size),
91
+ mode="bilinear",
92
+ align_corners=False,
93
+ )
94
+
95
+ prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
96
+ torch.int64
97
+ )
98
+ masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
99
+
100
+ orig_im_size = orig_im_size.to(torch.int64)
101
+ h, w = orig_im_size[0], orig_im_size[1]
102
+ masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
103
+ return masks
104
+
105
+ def select_masks(
106
+ self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
107
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
108
+ # Determine if we should return the multiclick mask or not from the number of points.
109
+ # The reweighting is used to avoid control flow.
110
+ score_reweight = torch.tensor(
111
+ [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
112
+ ).to(iou_preds.device)
113
+ score = iou_preds + (num_points - 2.5) * score_reweight
114
+ best_idx = torch.argmax(score, dim=1)
115
+ masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
116
+ iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
117
+
118
+ return masks, iou_preds
119
+
120
+ @torch.no_grad()
121
+ def forward(
122
+ self,
123
+ image_embeddings: torch.Tensor,
124
+ point_coords: torch.Tensor,
125
+ point_labels: torch.Tensor,
126
+ mask_input: torch.Tensor,
127
+ has_mask_input: torch.Tensor,
128
+ orig_im_size: torch.Tensor,
129
+ ):
130
+ sparse_embedding = self._embed_points(point_coords, point_labels)
131
+ dense_embedding = self._embed_masks(mask_input, has_mask_input)
132
+
133
+ masks, scores = self.model.mask_decoder.predict_masks(
134
+ image_embeddings=image_embeddings,
135
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
136
+ sparse_prompt_embeddings=sparse_embedding,
137
+ dense_prompt_embeddings=dense_embedding,
138
+ )
139
+
140
+ if self.use_stability_score:
141
+ scores = calculate_stability_score(
142
+ masks, self.model.mask_threshold, self.stability_score_offset
143
+ )
144
+
145
+ if self.return_single_mask:
146
+ masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
147
+
148
+ upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
149
+
150
+ if self.return_extra_metrics:
151
+ stability_scores = calculate_stability_score(
152
+ upscaled_masks, self.model.mask_threshold, self.stability_score_offset
153
+ )
154
+ areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
155
+ return upscaled_masks, scores, stability_scores, areas, masks
156
+
157
+ return upscaled_masks, scores, masks