John6666 committed
Commit 59e10a3 · verified · 1 Parent(s): cd64e85

Upload 2 files

Files changed (2):
  1. app.py +2 -2
  2. joycaption.py +21 -6
app.py CHANGED
@@ -49,14 +49,14 @@ with gr.Blocks(fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
  jc_temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature")
  jc_topp = gr.Slider(minimum=0, maximum=2.0, value=0.9, step=0.01, label="Top-P")
  jc_run_button = gr.Button("Caption", variant="primary")
-
  with gr.Column():
  jc_output_caption = gr.Textbox(label="Caption", show_copy_button=True)
  gr.Markdown(JC_DESC_MD, elem_classes="info")
  gr.LoginButton()
  gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
 
- jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_caption_type, jc_caption_tone, jc_caption_length, jc_tokens, jc_topp, jc_temperature], outputs=[jc_output_caption])
+ jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_caption_type, jc_caption_tone, jc_caption_length,
+ jc_tokens, jc_topp, jc_temperature, jc_text_model], outputs=[jc_output_caption])
  jc_text_model_button.click(change_text_model, [jc_text_model, jc_use_inference_client, jc_gguf, jc_nf4], [jc_text_model], show_api=False)
  #jc_text_model.change(get_repo_gguf, [jc_text_model], [jc_gguf], show_api=False)
  jc_use_inference_client.change(change_text_model, [jc_text_model, jc_use_inference_client], [jc_text_model], show_api=False)
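Below is a minimal, self-contained sketch of what this wiring change does: the text-model dropdown is appended to `inputs`, so Gradio passes its value positionally as the handler's new `model_name` argument. Component names and choices here are hypothetical, not the Space's actual layout.

import gradio as gr

def caption(image, model_name):
    # Placeholder handler standing in for stream_chat_mod
    return f"(would caption with {model_name})"

with gr.Blocks() as demo:
    img = gr.Image(type="pil", label="Input image")
    model = gr.Dropdown(choices=["bunnycore/LLama-3.1-8B-Matrix", "mistral-community/pixtral-12b"],
                        value="bunnycore/LLama-3.1-8B-Matrix", label="Text model")
    out = gr.Textbox(label="Caption")
    gr.Button("Caption", variant="primary").click(fn=caption, inputs=[img, model], outputs=[out])

demo.launch()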
joycaption.py CHANGED
@@ -30,9 +30,11 @@ BASE_DIR = Path(__file__).resolve().parent
  device = "cuda" if torch.cuda.is_available() else "cpu"
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
  use_inference_client = False
+ PIXTRAL_PATH = "mistral-community/pixtral-12b"
 
  llm_models = {
  "bunnycore/LLama-3.1-8B-Matrix": None,
+ #PIXTRAL_PATH: None,
  "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
  "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
  "DevQuasar/HermesNova-Llama-3.1-8B": None,
@@ -123,7 +125,6 @@ class ImageAdapter(nn.Module):
  def get_eot_embedding(self):
  return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
 
-
  # https://huggingface.co/docs/transformers/v4.44.2/gguf
  # https://github.com/city96/ComfyUI-GGUF/issues/7
  # https://github.com/THUDM/ChatGLM-6B/issues/18
@@ -147,6 +148,15 @@ def load_text_model(model_name: str=MODEL_PATH, gguf_file: Union[str, None]=None
  from transformers import BitsAndBytesConfig
  nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
  bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
+
+ if model_name == PIXTRAL_PATH:
+     from transformers import AutoProcessor, LlavaForConditionalGeneration
+     if is_nf4:
+         text_model = LlavaForConditionalGeneration.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
+         image_adapter = AutoProcessor.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16)
+     else:
+         text_model = LlavaForConditionalGeneration.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
+         image_adapter = AutoProcessor.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)
 
  print("Loading tokenizer")
  if gguf_file: tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
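For context, a standalone sketch of the Pixtral load path added above, assuming a CUDA machine with bitsandbytes installed; variable names mirror the diff, but this is not a drop-in replacement for load_text_model.

import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

PIXTRAL_PATH = "mistral-community/pixtral-12b"

# NF4 4-bit quantization config, matching the one built earlier in load_text_model
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)

# Quantized load; drop quantization_config for a plain bfloat16 load
text_model = LlavaForConditionalGeneration.from_pretrained(
    PIXTRAL_PATH, quantization_config=nf4_config, device_map="auto", torch_dtype=torch.bfloat16).eval()
# The processor plays the role of image_adapter in the diff; it does not need the quantization kwargs
image_adapter = AutoProcessor.from_pretrained(PIXTRAL_PATH)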
@@ -286,7 +296,8 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str,
 
  @spaces.GPU()
  @torch.no_grad()
- def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int], max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, progress=gr.Progress(track_tqdm=True)) -> str:
+ def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int],
+                     max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> str:
  global use_inference_client, text_model
  torch.cuda.empty_cache()
  gc.collect()
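A hedged usage sketch of the widened signature: the caption type/tone/length strings are illustrative values only, and the call assumes the matching text model has already been loaded via load_text_model.

from PIL import Image

img = Image.open("example.jpg")  # hypothetical input file
caption = stream_chat_mod(img, "Descriptive", "formal", "medium-length",
                          max_new_tokens=300, top_p=0.9, temperature=0.6,
                          model_name="mistral-community/pixtral-12b")
print(caption)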
@@ -312,8 +323,15 @@ def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: s
  prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
  print(f"Prompt: {prompt_str}")
 
+ # Pixtral
+ if model_name == PIXTRAL_PATH:
+     input_images = [input_image]
+     inputs = image_adapter(text=prompt_str, images=input_images, return_tensors="pt").to(device)
+     generate_ids = text_model.generate(**inputs, max_new_tokens=max_new_tokens)
+     output = image_adapter.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     return output.strip()
+
  # Preprocess image
- #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
  image = input_image.resize((384, 384), Image.LANCZOS)
  pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
  pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
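For reference, a standalone sketch of the processor → generate → batch_decode round-trip used in the Pixtral branch above. The [INST]/[IMG] prompt wrapping follows the mistral-community/pixtral-12b model card; whether prompt_str needs that wrapping inside the Space is an assumption, not something this commit shows.

from PIL import Image

prompt = "<s>[INST]Write a descriptive caption for this image.\n[IMG][/INST]"  # assumed prompt format
image = Image.open("example.jpg")  # hypothetical input file

# image_adapter / text_model as loaded in the Pixtral branch of load_text_model
inputs = image_adapter(text=prompt, images=[image], return_tensors="pt").to(text_model.device)
generate_ids = text_model.generate(**inputs, max_new_tokens=300)
output = image_adapter.batch_decode(generate_ids, skip_special_tokens=True,
                                    clean_up_tokenization_spaces=False)[0]
print(output.strip())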
@@ -352,9 +370,6 @@ def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: s
  attention_mask = torch.ones_like(input_ids)
 
  text_model.to(device)
- #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
- #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
- #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
  generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
  do_sample=True, suppress_tokens=None, top_p=top_p, temperature=temperature)
 
 
 
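As a side note, the now-parameterised sampling options in the final generate call could equivalently be bundled into a GenerationConfig; a hedged sketch, not part of this commit:

from transformers import GenerationConfig

# Equivalent to passing max_new_tokens / top_p / temperature directly to generate()
gen_config = GenerationConfig(max_new_tokens=300, do_sample=True,
                              top_p=0.9, temperature=0.6, suppress_tokens=None)
generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds,
                                   attention_mask=attention_mask, generation_config=gen_config)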