Yw22 committed
Commit bfb88c0 · 1 Parent(s): 9529cd5

[fix] fix mask np.uint8 bug

Files changed (2):
  1. app/src/brushedit_app.py +17 -17
  2. app/src/vlm_template.py +3 -3
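
The headline change casts masks to np.uint8 before they reach PIL. Below is a minimal sketch of the failure mode this guards against, assuming the mask arrives from the segmentation step as an int64 (or boolean/float) array rather than uint8; the exact dtype is not visible in the diff, so treat the example values as hypothetical.

    import numpy as np
    from PIL import Image

    # Hypothetical mask as a segmentation model might return it: HxW, but not uint8.
    mask = np.zeros((512, 512), dtype=np.int64)
    mask[100:200, 150:300] = 255

    # Without the cast, PIL rejects unsupported dtypes, e.g.
    # "TypeError: Cannot handle this data type: (1, 1), <i8".
    try:
        Image.fromarray(np.squeeze(mask))
    except TypeError as err:
        print(err)

    # With the cast (as in this commit), the mask becomes a plain 8-bit
    # grayscale image that PIL can resize.
    mask_img = Image.fromarray(np.squeeze(mask).astype(np.uint8))
    mask_img = mask_img.resize((640, 640), Image.Resampling.NEAREST)
    resized_mask = np.array(mask_img)  # back to an HxW uint8 array
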
app/src/brushedit_app.py CHANGED
@@ -528,23 +528,23 @@ def update_vlm_model(vlm_name):
         else:
             if os.path.exists(vlm_local_path):
                 vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
-                vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
+                vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype=torch_dtype, device_map=device)
             else:
                 if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
                     vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
-                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
+                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch_dtype, device_map=device)
                 elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
                     vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
-                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
+                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype=torch_dtype, device_map=device)
                 elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
                     vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
-                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
+                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype=torch_dtype, device_map=device)
                 elif vlm_name == "llava-v1.6-34b-hf (Preload)":
                     vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
-                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
+                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype=torch_dtype, device_map=device)
                 elif vlm_name == "llava-next-72b-hf (Preload)":
                     vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
-                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
+                    vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype=torch_dtype, device_map=device)
     elif vlm_type == "qwen2-vl":
         if vlm_processor != "" and vlm_model != "":
             vlm_model.to(device)
@@ -552,17 +552,17 @@ def update_vlm_model(vlm_name):
         else:
             if os.path.exists(vlm_local_path):
                 vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
-                vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
+                vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype=torch_dtype, device_map=device)
             else:
                 if vlm_name == "qwen2-vl-2b-instruct (Preload)":
                     vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
+                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch_dtype, device_map=device)
                 elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
                     vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
-                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
+                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch_dtype, device_map=device)
                 elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
                     vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
-                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
+                    vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype=torch_dtype, device_map=device)
     elif vlm_type == "openai":
         pass
     return "success"
@@ -654,10 +654,10 @@ def process(input_image,
         original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
         original_image = np.array(original_image)
         if input_mask is not None:
-            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+            input_mask = resize(Image.fromarray(np.squeeze(input_mask).astype(np.uint8)), target_width=int(output_w), target_height=int(output_h))
             input_mask = np.array(input_mask)
         if original_mask is not None:
-            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+            original_mask = resize(Image.fromarray(np.squeeze(original_mask).astype(np.uint8)), target_width=int(output_w), target_height=int(output_h))
             original_mask = np.array(original_mask)
         gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
     else:
@@ -673,10 +673,10 @@ def process(input_image,
         original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
         original_image = np.array(original_image)
         if input_mask is not None:
-            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+            input_mask = resize(Image.fromarray(np.squeeze(input_mask).astype(np.uint8)), target_width=int(output_w), target_height=int(output_h))
             input_mask = np.array(input_mask)
         if original_mask is not None:
-            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+            original_mask = resize(Image.fromarray(np.squeeze(original_mask).astype(np.uint8)), target_width=int(output_w), target_height=int(output_h))
             original_mask = np.array(original_mask)
 
     if invert_mask_state:
@@ -722,7 +722,7 @@ def process(input_image,
                 sam_predictor,
                 sam_automask_generator,
                 groundingdino_model,
-                device)
+                device).astype(np.uint8)
     except Exception as e:
         raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
 
@@ -831,9 +831,9 @@ def process_mask(input_image,
                 sam_predictor,
                 sam_automask_generator,
                 groundingdino_model,
-                device)
+                device).astype(np.uint8)
     else:
-        original_mask = input_mask
+        original_mask = input_mask.astype(np.uint8)
     category = None
 
     ## resize mask if needed
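
Besides the mask cast, this file now passes the module-level torch_dtype and device to from_pretrained instead of "auto". A short sketch of that loading pattern follows; the CUDA-availability fallback is an assumption added for illustration only, since the commit itself (see vlm_template.py below) hard-codes device = "cuda".

    import torch
    from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

    # Assumed fallback; the commit pins device = "cuda" in vlm_template.py.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
    vlm_processor = LlavaNextProcessor.from_pretrained(model_id)
    # Explicit dtype and device instead of device_map="auto", as in the diff above.
    vlm_model = LlavaNextForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch_dtype, device_map=device
    )
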
app/src/vlm_template.py CHANGED
@@ -7,7 +7,7 @@ from transformers import (
     Qwen2VLForConditionalGeneration, Qwen2VLProcessor
 )
 ## init device
-device = "cpu"
+device = "cuda"
 torch_dtype = torch.float16
 
 
@@ -103,10 +103,10 @@ vlms_list = [
         ),
         "model": Qwen2VLForConditionalGeneration.from_pretrained(
             "models/vlms/Qwen2-VL-7B-Instruct", torch_dtype=torch_dtype, device_map=device
-        ).to("cpu") if os.path.exists("models/vlms/Qwen2-VL-7B-Instruct") else
+        ).to(device) if os.path.exists("models/vlms/Qwen2-VL-7B-Instruct") else
         Qwen2VLForConditionalGeneration.from_pretrained(
             "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch_dtype, device_map=device
-        ).to("cpu"),
+        ).to(device),
     },
     {
         "type": "openai",