John6666 committed on
Commit bfd1b89 · verified · 1 Parent(s): 59e10a3

Upload 2 files

Files changed (1): joycaption.py (+6 -122)

joycaption.py CHANGED
@@ -33,8 +33,9 @@ use_inference_client = False
 PIXTRAL_PATH = "mistral-community/pixtral-12b"
 
 llm_models = {
-    "bunnycore/LLama-3.1-8B-Matrix": None,
+    "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2": None,
     #PIXTRAL_PATH: None,
+    "bunnycore/LLama-3.1-8B-Matrix": None,
     "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
     "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
     "DevQuasar/HermesNova-Llama-3.1-8B": None,
@@ -157,6 +158,8 @@ def load_text_model(model_name: str=MODEL_PATH, gguf_file: Union[str, None]=None
     else:
         text_model = LlavaForConditionalGeneration.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
         image_adapter = AutoProcessor.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)
+    tokenizer = None
+    peft_config = None
 
     print("Loading tokenizer")
     if gguf_file: tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
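The two added assignments appear to reset the tokenizer and PEFT-config globals before the tokenizer is (re)loaded below. A minimal sketch of that reset-before-reload pattern, assuming module-level globals like those in joycaption.py; the helper name is hypothetical:

# Hypothetical helper illustrating the reset-before-reload pattern.
import gc
import torch

def release_text_model():
    global text_model, tokenizer, image_adapter, peft_config
    text_model = None
    tokenizer = None
    image_adapter = None
    peft_config = None
    gc.collect()                 # drop Python references to the old objects
    torch.cuda.empty_cache()     # return freed GPU memory to the allocator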
@@ -217,88 +220,10 @@ clip_model.eval().requires_grad_(False).to(device)
 load_text_model()
 
 @spaces.GPU()
-@torch.no_grad()
-def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int]) -> str:
-    torch.cuda.empty_cache()
-
-    # 'any' means no length specified
-    length = None if caption_length == "any" else caption_length
-
-    if isinstance(length, str):
-        try:
-            length = int(length)
-        except ValueError:
-            pass
-
-    # 'rng-tags' and 'training_prompt' don't have formal/informal tones
-    if caption_type == "rng-tags" or caption_type == "training_prompt":
-        caption_tone = "formal"
-
-    # Build prompt
-    prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
-    if prompt_key not in CAPTION_TYPE_MAP:
-        raise ValueError(f"Invalid caption type: {prompt_key}")
-
-    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
-    print(f"Prompt: {prompt_str}")
-
-    # Preprocess image
-    #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
-    image = input_image.resize((384, 384), Image.LANCZOS)
-    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
-    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-    pixel_values = pixel_values.to('cuda')
-
-    # Tokenize the prompt
-    prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
-
-    # Embed image
-    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
-        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
-        image_features = vision_outputs.hidden_states
-        embedded_images = image_adapter(image_features)
-        embedded_images = embedded_images.to('cuda')
-
-    # Embed prompt
-    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
-    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
-    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
-    eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
-
-    # Construct prompts
-    inputs_embeds = torch.cat([
-        embedded_bos.expand(embedded_images.shape[0], -1, -1),
-        embedded_images.to(dtype=embedded_bos.dtype),
-        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
-        eot_embed.expand(embedded_images.shape[0], -1, -1),
-    ], dim=1)
-
-    input_ids = torch.cat([
-        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
-        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
-        prompt,
-        torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
-    ], dim=1).to('cuda')
-    attention_mask = torch.ones_like(input_ids)
-
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
-    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
-
-    # Trim off the prompt
-    generate_ids = generate_ids[:, input_ids.shape[1]:]
-    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
-        generate_ids = generate_ids[:, :-1]
-
-    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
-
-    return caption.strip()
-
-@spaces.GPU()
-@torch.no_grad()
+@torch.inference_mode()
 def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int],
                     max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> str:
-    global use_inference_client, text_model
+    global tokenizer, text_model, image_adapter, peft_config, text_model_client, use_inference_client
     torch.cuda.empty_cache()
     gc.collect()
 
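The surviving stream_chat_mod is now decorated with @torch.inference_mode() rather than @torch.no_grad(). A minimal sketch of the difference between the two decorators, illustration only and not taken from the file:

# Both decorators disable gradient tracking; inference_mode additionally skips
# view/version-counter bookkeeping, so tensors created under it cannot feed a
# later autograd pass.
import torch

@torch.no_grad()
def with_no_grad(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # no grads recorded; result may still be reused by autograd later

@torch.inference_mode()
def with_inference_mode(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # no grads recorded and the output is an inference-only tensor

print(with_no_grad(torch.ones(2)), with_inference_mode(torch.ones(2)))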
@@ -476,44 +401,3 @@ def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_f
         return gr.update(choices=get_text_model())
     except Exception as e:
         raise gr.Error(f"Model load error: {model_name}, {e}")
-
-
-# original UI
-with gr.Blocks() as demo:
-    gr.HTML(TITLE)
-
-    with gr.Row():
-        with gr.Column():
-            input_image = gr.Image(type="pil", label="Input Image")
-
-            caption_type = gr.Dropdown(
-                choices=["descriptive", "training_prompt", "rng-tags"],
-                label="Caption Type",
-                value="descriptive",
-            )
-
-            caption_tone = gr.Dropdown(
-                choices=["formal", "informal"],
-                label="Caption Tone",
-                value="formal",
-            )
-
-            caption_length = gr.Dropdown(
-                choices=["any", "very short", "short", "medium-length", "long", "very long"] +
-                        [str(i) for i in range(20, 261, 10)],
-                label="Caption Length",
-                value="any",
-            )
-
-            gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags` and `training_prompt`.")
-
-            run_button = gr.Button("Caption")
-
-        with gr.Column():
-            output_caption = gr.Textbox(label="Caption")
-
-    run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length], outputs=[output_caption])
-
-
-if __name__ == "__main__":
-    demo.launch()
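This hunk removes the original Gradio Blocks UI and its demo.launch() entry point along with the deleted stream_chat. A minimal sketch, an assumption and not part of the commit, of the same UI wired to the remaining stream_chat_mod, with its extra generation parameters left at their defaults:

# Hypothetical re-wiring of the removed UI to stream_chat_mod (defined in joycaption.py).
import gradio as gr

with gr.Blocks() as demo:
    input_image = gr.Image(type="pil", label="Input Image")
    caption_type = gr.Dropdown(["descriptive", "training_prompt", "rng-tags"], value="descriptive", label="Caption Type")
    caption_tone = gr.Dropdown(["formal", "informal"], value="formal", label="Caption Tone")
    caption_length = gr.Dropdown(["any", "very short", "short", "medium-length", "long", "very long"], value="any", label="Caption Length")
    run_button = gr.Button("Caption")
    output_caption = gr.Textbox(label="Caption")
    # max_new_tokens, top_p, temperature and model_name fall back to the function defaults.
    run_button.click(fn=stream_chat_mod, inputs=[input_image, caption_type, caption_tone, caption_length], outputs=[output_caption])

if __name__ == "__main__":
    demo.launch()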
 