X-Lai committed
Commit c899f8b · 1 Parent(s): 3d9efe2

fix bug in inference

Former-commit-id: 4b1776203d3410cd71b2d4720fc7d9cc61d1db3c

Files changed (5)
  1. app.py +1 -3
  2. chat.py +2 -3
  3. merge_lora_weights_and_save_hf_model.py +0 -0
  4. model/LISA.py +3 -1
  5. train_ds.py +0 -1
app.py CHANGED
@@ -92,7 +92,6 @@ if args.load_in_4bit:
     kwargs.update(
         {
             "torch_dtype": torch.half,
-            "device_map": "auto",
             "load_in_4bit": True,
             "quantization_config": BitsAndBytesConfig(
                 load_in_4bit=True,
@@ -107,7 +106,6 @@ elif args.load_in_8bit:
     kwargs.update(
         {
             "torch_dtype": torch.half,
-            "device_map": "auto",
             "quantization_config": BitsAndBytesConfig(
                 llm_int8_skip_modules=["visual_model"],
                 load_in_8bit=True,
@@ -116,7 +114,7 @@ elif args.load_in_8bit:
     )

 model = LISAForCausalLM.from_pretrained(
-    args.version, low_cpu_mem_usage=True, seg_token_idx=args.seg_token_idx, **kwargs
+    args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, **kwargs
 )

 model.config.eos_token_id = tokenizer.eos_token_id
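
Read together, the app.py hunks leave model loading looking roughly like the sketch below: quantization is driven entirely by BitsAndBytesConfig (the "device_map": "auto" entry is gone), and the vision tower path is now forwarded to from_pretrained so model/LISA.py can pick it up. This is a minimal reconstruction for orientation only; just the lines shown in the diff above are confirmed, and anything beyond them (e.g. extra BitsAndBytesConfig fields) is an assumption.

    import torch
    from transformers import BitsAndBytesConfig

    # Assemble from_pretrained kwargs (no "device_map": "auto" after this commit).
    kwargs = {"torch_dtype": torch.half}
    if args.load_in_4bit:
        kwargs.update(
            {
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,  # other 4-bit fields omitted; assumed, not in the diff
                ),
            }
        )

    model = LISAForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,   # new: vision tower path passed through explicitly
        seg_token_idx=args.seg_token_idx,
        **kwargs,
    )
    model.config.eos_token_id = tokenizer.eos_token_id
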
chat.py CHANGED
@@ -90,7 +90,6 @@ def main(args):
     kwargs.update(
         {
             "torch_dtype": torch.half,
-            "device_map": "auto",
             "load_in_4bit": True,
             "quantization_config": BitsAndBytesConfig(
                 load_in_4bit=True,
@@ -105,7 +104,6 @@ def main(args):
     kwargs.update(
         {
             "torch_dtype": torch.half,
-            "device_map": "auto",
             "quantization_config": BitsAndBytesConfig(
                 llm_int8_skip_modules=["visual_model"],
                 load_in_8bit=True,
@@ -114,7 +112,7 @@ def main(args):
     )

     model = LISAForCausalLM.from_pretrained(
-        args.version, low_cpu_mem_usage=True, seg_token_idx=args.seg_token_idx, **kwargs
+        args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, **kwargs
     )

     model.config.eos_token_id = tokenizer.eos_token_id
@@ -223,6 +221,7 @@ def main(args):

         text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
         text_output = text_output.replace("\n", "").replace("  ", " ")
+        print("text_output: ", text_output)

         for i, pred_mask in enumerate(pred_masks):
             if pred_mask.shape[0] == 0:
merge_lora_weights_and_save_hf_model.py CHANGED
File without changes
model/LISA.py CHANGED
@@ -134,7 +134,9 @@ class LISAForCausalLM(LlavaLlamaForCausalLM):
             self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
             self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
             self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
-
+        else:
+            config.mm_vision_tower = config.vision_tower
+
         self.seg_token_idx = kwargs.pop("seg_token_idx")

         super().__init__(config)
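
The new else branch is the substance of the fix. As a rough sketch of the constructor it slots into (the enclosing if guard and its body are assumed from context; only the else lines appear in this diff), a config saved by a previous LISA run already records vision_tower, so it is mirrored into mm_vision_tower, while a fresh base config takes the path from the vision_tower kwarg that app.py and chat.py now pass:

    class LISAForCausalLM(LlavaLlamaForCausalLM):
        def __init__(self, config, **kwargs):
            # Hypothetical reconstruction of the surrounding branch; only the
            # "else" block below is confirmed by this commit.
            if not hasattr(config, "train_mask_decoder"):
                # Fresh base checkpoint: take the vision tower path from kwargs
                # (now supplied as vision_tower=args.vision_tower by the callers).
                config.mm_vision_tower = kwargs.get("vision_tower")
                self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
                self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
                self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
            else:
                # Saved LISA checkpoint: the config already stores vision_tower,
                # so copy it to mm_vision_tower for the LLaVA vision modules.
                config.mm_vision_tower = config.vision_tower

            self.seg_token_idx = kwargs.pop("seg_token_idx")
            super().__init__(config)
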
train_ds.py CHANGED
@@ -90,7 +90,6 @@ def parse_args(args):
     parser.add_argument("--eval_only", action="store_true", default=False)
     parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
     parser.add_argument("--out_dim", default=256, type=int)
-    parser.add_argument("--weight", default="", type=str)
     parser.add_argument("--resume", default="", type=str)
     parser.add_argument("--print_freq", default=1, type=int)
     parser.add_argument("--start_epoch", default=0, type=int)