Upload 2 files

joycaption.py (CHANGED, +6 -122)
@@ -33,8 +33,9 @@ use_inference_client = False
 PIXTRAL_PATH = "mistral-community/pixtral-12b"
 
 llm_models = {
-    "
+    "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2": None,
     #PIXTRAL_PATH: None,
+    "bunnycore/LLama-3.1-8B-Matrix": None,
     "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
     "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
     "DevQuasar/HermesNova-Llama-3.1-8B": None,
@@ -157,6 +158,8 @@ def load_text_model(model_name: str=MODEL_PATH, gguf_file: Union[str, None]=None
     else:
         text_model = LlavaForConditionalGeneration.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
         image_adapter = AutoProcessor.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)
+        tokenizer = None
+        peft_config = None
 
     print("Loading tokenizer")
     if gguf_file: tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
@@ -217,88 +220,10 @@ clip_model.eval().requires_grad_(False).to(device)
 load_text_model()
 
 @spaces.GPU()
-@torch.
-def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int]) -> str:
-    torch.cuda.empty_cache()
-
-    # 'any' means no length specified
-    length = None if caption_length == "any" else caption_length
-
-    if isinstance(length, str):
-        try:
-            length = int(length)
-        except ValueError:
-            pass
-
-    # 'rng-tags' and 'training_prompt' don't have formal/informal tones
-    if caption_type == "rng-tags" or caption_type == "training_prompt":
-        caption_tone = "formal"
-
-    # Build prompt
-    prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
-    if prompt_key not in CAPTION_TYPE_MAP:
-        raise ValueError(f"Invalid caption type: {prompt_key}")
-
-    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
-    print(f"Prompt: {prompt_str}")
-
-    # Preprocess image
-    #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
-    image = input_image.resize((384, 384), Image.LANCZOS)
-    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
-    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-    pixel_values = pixel_values.to('cuda')
-
-    # Tokenize the prompt
-    prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
-
-    # Embed image
-    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
-        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
-        image_features = vision_outputs.hidden_states
-        embedded_images = image_adapter(image_features)
-        embedded_images = embedded_images.to('cuda')
-
-    # Embed prompt
-    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
-    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
-    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
-    eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
-
-    # Construct prompts
-    inputs_embeds = torch.cat([
-        embedded_bos.expand(embedded_images.shape[0], -1, -1),
-        embedded_images.to(dtype=embedded_bos.dtype),
-        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
-        eot_embed.expand(embedded_images.shape[0], -1, -1),
-    ], dim=1)
-
-    input_ids = torch.cat([
-        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
-        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
-        prompt,
-        torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
-    ], dim=1).to('cuda')
-    attention_mask = torch.ones_like(input_ids)
-
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
-    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)  # Uses the default which is temp=0.6, top_p=0.9
-
-    # Trim off the prompt
-    generate_ids = generate_ids[:, input_ids.shape[1]:]
-    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
-        generate_ids = generate_ids[:, :-1]
-
-    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
-
-    return caption.strip()
-
-@spaces.GPU()
-@torch.no_grad()
+@torch.inference_mode()
 def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int],
                     max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> str:
-    global
+    global tokenizer, text_model, image_adapter, peft_config, text_model_client, use_inference_client
     torch.cuda.empty_cache()
     gc.collect()
 
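Note: this hunk deletes the original stream_chat function and moves stream_chat_mod from @torch.no_grad() to @torch.inference_mode(). Both disable gradient tracking; inference_mode additionally marks outputs as inference tensors (no view/version-counter bookkeeping), which is usually a little faster and lighter for pure inference. A small self-contained illustration, not taken from the Space:

import torch

@torch.inference_mode()
def caption_step(x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded inside inference_mode.
    return x @ x.T

y = caption_step(torch.randn(4, 8))
print(y.requires_grad)        # False
print(torch.is_inference(y))  # True: y is an inference tensor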
@@ -476,44 +401,3 @@ def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_f
         return gr.update(choices=get_text_model())
     except Exception as e:
         raise gr.Error(f"Model load error: {model_name}, {e}")
-
-
-# original UI
-with gr.Blocks() as demo:
-    gr.HTML(TITLE)
-
-    with gr.Row():
-        with gr.Column():
-            input_image = gr.Image(type="pil", label="Input Image")
-
-            caption_type = gr.Dropdown(
-                choices=["descriptive", "training_prompt", "rng-tags"],
-                label="Caption Type",
-                value="descriptive",
-            )
-
-            caption_tone = gr.Dropdown(
-                choices=["formal", "informal"],
-                label="Caption Tone",
-                value="formal",
-            )
-
-            caption_length = gr.Dropdown(
-                choices=["any", "very short", "short", "medium-length", "long", "very long"] +
-                        [str(i) for i in range(20, 261, 10)],
-                label="Caption Length",
-                value="any",
-            )
-
-            gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags` and `training_prompt`.")
-
-            run_button = gr.Button("Caption")
-
-        with gr.Column():
-            output_caption = gr.Textbox(label="Caption")
-
-    run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length], outputs=[output_caption])
-
-
-if __name__ == "__main__":
-    demo.launch()
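Note: with the original gr.Blocks demo (wired to the now-deleted stream_chat) removed, any UI in the file has to call stream_chat_mod instead. A minimal hypothetical Blocks layout showing that wiring; the real Space's UI is more elaborate, and the extra generation arguments simply fall back to their defaults.

import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            caption_type = gr.Dropdown(["descriptive", "training_prompt", "rng-tags"], value="descriptive", label="Caption Type")
            caption_tone = gr.Dropdown(["formal", "informal"], value="formal", label="Caption Tone")
            caption_length = gr.Dropdown(["any", "very short", "short", "medium-length", "long", "very long"], value="any", label="Caption Length")
            run_button = gr.Button("Caption")
        with gr.Column():
            output_caption = gr.Textbox(label="Caption")

    # stream_chat_mod is defined earlier in joycaption.py; max_new_tokens, top_p,
    # temperature and model_name keep their default values here.
    run_button.click(
        fn=stream_chat_mod,
        inputs=[input_image, caption_type, caption_tone, caption_length],
        outputs=[output_caption],
    )

if __name__ == "__main__":
    demo.launch()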