Spaces:
Paused
Paused
X-Lai
committed on
Commit
·
c899f8b
1
Parent(s):
3d9efe2
fix bug in inference
Browse files
Former-commit-id: 4b1776203d3410cd71b2d4720fc7d9cc61d1db3c
- app.py +1 -3
- chat.py +2 -3
- merge_lora_weights_and_save_hf_model.py +0 -0
- model/LISA.py +3 -1
- train_ds.py +0 -1
app.py
CHANGED
@@ -92,7 +92,6 @@ if args.load_in_4bit:
|
|
92 |
kwargs.update(
|
93 |
{
|
94 |
"torch_dtype": torch.half,
|
95 |
-
"device_map": "auto",
|
96 |
"load_in_4bit": True,
|
97 |
"quantization_config": BitsAndBytesConfig(
|
98 |
load_in_4bit=True,
|
@@ -107,7 +106,6 @@ elif args.load_in_8bit:
|
|
107 |
kwargs.update(
|
108 |
{
|
109 |
"torch_dtype": torch.half,
|
110 |
-
"device_map": "auto",
|
111 |
"quantization_config": BitsAndBytesConfig(
|
112 |
llm_int8_skip_modules=["visual_model"],
|
113 |
load_in_8bit=True,
|
@@ -116,7 +114,7 @@ elif args.load_in_8bit:
|
|
116 |
)
|
117 |
|
118 |
model = LISAForCausalLM.from_pretrained(
|
119 |
-
args.version, low_cpu_mem_usage=True, seg_token_idx=args.seg_token_idx, **kwargs
|
120 |
)
|
121 |
|
122 |
model.config.eos_token_id = tokenizer.eos_token_id
|
|
|
92 |
kwargs.update(
|
93 |
{
|
94 |
"torch_dtype": torch.half,
|
|
|
95 |
"load_in_4bit": True,
|
96 |
"quantization_config": BitsAndBytesConfig(
|
97 |
load_in_4bit=True,
|
|
|
106 |
kwargs.update(
|
107 |
{
|
108 |
"torch_dtype": torch.half,
|
|
|
109 |
"quantization_config": BitsAndBytesConfig(
|
110 |
llm_int8_skip_modules=["visual_model"],
|
111 |
load_in_8bit=True,
|
|
|
114 |
)
|
115 |
|
116 |
model = LISAForCausalLM.from_pretrained(
|
117 |
+
args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, **kwargs
|
118 |
)
|
119 |
|
120 |
model.config.eos_token_id = tokenizer.eos_token_id
|
chat.py
CHANGED
@@ -90,7 +90,6 @@ def main(args):
|
|
90 |
kwargs.update(
|
91 |
{
|
92 |
"torch_dtype": torch.half,
|
93 |
-
"device_map": "auto",
|
94 |
"load_in_4bit": True,
|
95 |
"quantization_config": BitsAndBytesConfig(
|
96 |
load_in_4bit=True,
|
@@ -105,7 +104,6 @@ def main(args):
|
|
105 |
kwargs.update(
|
106 |
{
|
107 |
"torch_dtype": torch.half,
|
108 |
-
"device_map": "auto",
|
109 |
"quantization_config": BitsAndBytesConfig(
|
110 |
llm_int8_skip_modules=["visual_model"],
|
111 |
load_in_8bit=True,
|
@@ -114,7 +112,7 @@ def main(args):
|
|
114 |
)
|
115 |
|
116 |
model = LISAForCausalLM.from_pretrained(
|
117 |
-
args.version, low_cpu_mem_usage=True, seg_token_idx=args.seg_token_idx, **kwargs
|
118 |
)
|
119 |
|
120 |
model.config.eos_token_id = tokenizer.eos_token_id
|
@@ -223,6 +221,7 @@ def main(args):
|
|
223 |
|
224 |
text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
|
225 |
text_output = text_output.replace("\n", "").replace(" ", " ")
|
|
|
226 |
|
227 |
for i, pred_mask in enumerate(pred_masks):
|
228 |
if pred_mask.shape[0] == 0:
|
|
|
90 |
kwargs.update(
|
91 |
{
|
92 |
"torch_dtype": torch.half,
|
|
|
93 |
"load_in_4bit": True,
|
94 |
"quantization_config": BitsAndBytesConfig(
|
95 |
load_in_4bit=True,
|
|
|
104 |
kwargs.update(
|
105 |
{
|
106 |
"torch_dtype": torch.half,
|
|
|
107 |
"quantization_config": BitsAndBytesConfig(
|
108 |
llm_int8_skip_modules=["visual_model"],
|
109 |
load_in_8bit=True,
|
|
|
112 |
)
|
113 |
|
114 |
model = LISAForCausalLM.from_pretrained(
|
115 |
+
args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, **kwargs
|
116 |
)
|
117 |
|
118 |
model.config.eos_token_id = tokenizer.eos_token_id
|
|
|
221 |
|
222 |
text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
|
223 |
text_output = text_output.replace("\n", "").replace(" ", " ")
|
224 |
+
print("text_output: ", text_output)
|
225 |
|
226 |
for i, pred_mask in enumerate(pred_masks):
|
227 |
if pred_mask.shape[0] == 0:
|
merge_lora_weights_and_save_hf_model.py
CHANGED
File without changes
|
model/LISA.py
CHANGED
@@ -134,7 +134,9 @@ class LISAForCausalLM(LlavaLlamaForCausalLM):
|
|
134 |
self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
|
135 |
self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
|
136 |
self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
|
137 |
-
|
|
|
|
|
138 |
self.seg_token_idx = kwargs.pop("seg_token_idx")
|
139 |
|
140 |
super().__init__(config)
|
|
|
134 |
self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
|
135 |
self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
|
136 |
self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
|
137 |
+
else:
|
138 |
+
config.mm_vision_tower = config.vision_tower
|
139 |
+
|
140 |
self.seg_token_idx = kwargs.pop("seg_token_idx")
|
141 |
|
142 |
super().__init__(config)
|
train_ds.py
CHANGED
@@ -90,7 +90,6 @@ def parse_args(args):
|
|
90 |
parser.add_argument("--eval_only", action="store_true", default=False)
|
91 |
parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
|
92 |
parser.add_argument("--out_dim", default=256, type=int)
|
93 |
-
parser.add_argument("--weight", default="", type=str)
|
94 |
parser.add_argument("--resume", default="", type=str)
|
95 |
parser.add_argument("--print_freq", default=1, type=int)
|
96 |
parser.add_argument("--start_epoch", default=0, type=int)
|
|
|
90 |
parser.add_argument("--eval_only", action="store_true", default=False)
|
91 |
parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
|
92 |
parser.add_argument("--out_dim", default=256, type=int)
|
|
|
93 |
parser.add_argument("--resume", default="", type=str)
|
94 |
parser.add_argument("--print_freq", default=1, type=int)
|
95 |
parser.add_argument("--start_epoch", default=0, type=int)
|