Upload 2 files
- app.py +2 -2
- joycaption.py +21 -6
app.py
CHANGED

@@ -49,14 +49,14 @@ with gr.Blocks(fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
             jc_temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature")
             jc_topp = gr.Slider(minimum=0, maximum=2.0, value=0.9, step=0.01, label="Top-P")
             jc_run_button = gr.Button("Caption", variant="primary")
-
         with gr.Column():
             jc_output_caption = gr.Textbox(label="Caption", show_copy_button=True)
             gr.Markdown(JC_DESC_MD, elem_classes="info")
     gr.LoginButton()
     gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
 
-    jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_caption_type, jc_caption_tone, jc_caption_length,
+    jc_run_button.click(fn=stream_chat_mod, inputs=[jc_input_image, jc_caption_type, jc_caption_tone, jc_caption_length,
+                        jc_tokens, jc_topp, jc_temperature, jc_text_model], outputs=[jc_output_caption])
     jc_text_model_button.click(change_text_model, [jc_text_model, jc_use_inference_client, jc_gguf, jc_nf4], [jc_text_model], show_api=False)
     #jc_text_model.change(get_repo_gguf, [jc_text_model], [jc_gguf], show_api=False)
     jc_use_inference_client.change(change_text_model, [jc_text_model, jc_use_inference_client], [jc_text_model], show_api=False)
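For reference, Gradio passes the `inputs` component list to `fn` positionally, so the rewired click handler maps `jc_tokens`, `jc_topp`, `jc_temperature`, and the new `jc_text_model` dropdown onto `stream_chat_mod`'s `max_new_tokens`, `top_p`, `temperature`, and `model_name` parameters. A minimal sketch of the equivalent direct call, with hypothetical argument values (the caption type/tone/length strings and the image path are placeholders, not necessarily the Space's exact choices):

```python
from PIL import Image
from joycaption import stream_chat_mod  # assumes joycaption.py is importable, as in this Space

image = Image.open("example.jpg")        # hypothetical input image
caption = stream_chat_mod(
    image,                               # jc_input_image
    "Descriptive",                       # jc_caption_type   (placeholder value)
    "formal",                            # jc_caption_tone   (placeholder value)
    "any",                               # jc_caption_length (placeholder value)
    300,                                 # jc_tokens       -> max_new_tokens
    0.9,                                 # jc_topp         -> top_p
    0.6,                                 # jc_temperature  -> temperature
    "bunnycore/LLama-3.1-8B-Matrix",     # jc_text_model   -> model_name (the new input)
)
# Assumes the selected text model has already been loaded by the Space.
print(caption)
```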
joycaption.py
CHANGED

@@ -30,9 +30,11 @@ BASE_DIR = Path(__file__).resolve().parent
 device = "cuda" if torch.cuda.is_available() else "cpu"
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 use_inference_client = False
+PIXTRAL_PATH = "mistral-community/pixtral-12b"
 
 llm_models = {
     "bunnycore/LLama-3.1-8B-Matrix": None,
+    #PIXTRAL_PATH: None,
     "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
     "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
     "DevQuasar/HermesNova-Llama-3.1-8B": None,
@@ -123,7 +125,6 @@ class ImageAdapter(nn.Module):
     def get_eot_embedding(self):
         return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
 
-
 # https://huggingface.co/docs/transformers/v4.44.2/gguf
 # https://github.com/city96/ComfyUI-GGUF/issues/7
 # https://github.com/THUDM/ChatGLM-6B/issues/18
@@ -147,6 +148,15 @@ def load_text_model(model_name: str=MODEL_PATH, gguf_file: Union[str, None]=None
         from transformers import BitsAndBytesConfig
         nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                         bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
+
+    if model_name == PIXTRAL_PATH:
+        from transformers import AutoProcessor, LlavaForConditionalGeneration
+        if is_nf4:
+            text_model = LlavaForConditionalGeneration.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
+            image_adapter = AutoProcessor.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16)
+        else:
+            text_model = LlavaForConditionalGeneration.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
+            image_adapter = AutoProcessor.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)
 
     print("Loading tokenizer")
     if gguf_file: tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
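The new branch loads Pixtral through the standard LLaVA-style classes rather than JoyCaption's custom `ImageAdapter`, reusing the `image_adapter` variable to hold the `AutoProcessor`. A minimal standalone sketch of the 4-bit NF4 load path it adds, assuming `bitsandbytes` is installed and using `device_map="auto"` instead of the Space's explicit `device`:

```python
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

PIXTRAL_PATH = "mistral-community/pixtral-12b"

# Same NF4 settings as in the diff: 4-bit weights, double quantization, bf16 compute.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = LlavaForConditionalGeneration.from_pretrained(
    PIXTRAL_PATH,
    quantization_config=nf4_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
processor = AutoProcessor.from_pretrained(PIXTRAL_PATH)  # the processor itself needs no quantization settings
```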
@@ -286,7 +296,8 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str,
 
 @spaces.GPU()
 @torch.no_grad()
-def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int],
+def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int],
+                    max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> str:
     global use_inference_client, text_model
     torch.cuda.empty_cache()
     gc.collect()
@@ -312,8 +323,15 @@ def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: s
     prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
     print(f"Prompt: {prompt_str}")
 
+    # Pixtral
+    if model_name == PIXTRAL_PATH:
+        input_images = [input_image]
+        inputs = image_adapter(text=prompt_str, images=input_images, return_tensors="pt").to(device)
+        generate_ids = text_model.generate(**inputs, max_new_tokens=max_new_tokens)
+        output = image_adapter.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        return output.strip()
+
     # Preprocess image
-    #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
     image = input_image.resize((384, 384), Image.LANCZOS)
     pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
     pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
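For context, the upstream mistral-community/pixtral-12b model card drives the processor with an instruction-formatted prompt containing an explicit [IMG] placeholder per image. A self-contained captioning sketch in that style (an illustration of the processor/generate/decode flow, not the Space's exact path, which passes prompt_str straight through):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

PIXTRAL_PATH = "mistral-community/pixtral-12b"
model = LlavaForConditionalGeneration.from_pretrained(
    PIXTRAL_PATH, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor = AutoProcessor.from_pretrained(PIXTRAL_PATH)

image = Image.open("example.jpg")  # hypothetical input image
# Instruction-style prompt with one [IMG] slot, following the model card.
prompt = "<s>[INST]Write a descriptive caption for this image.\n[IMG][/INST]"

inputs = processor(text=prompt, images=[image], return_tensors="pt").to(model.device)
with torch.no_grad():
    generate_ids = model.generate(**inputs, max_new_tokens=300)
caption = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(caption.strip())
```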
@@ -352,9 +370,6 @@ def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: s
     attention_mask = torch.ones_like(input_ids)
 
     text_model.to(device)
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
     generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
                                        do_sample=True, suppress_tokens=None, top_p=top_p, temperature=temperature)
 
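With the commented-out generate variants removed, sampling is now controlled entirely by the max_new_tokens, top_p, and temperature arguments coming from the UI sliders. The same settings could equivalently be bundled into a GenerationConfig; a small sketch of that alternative, not what the Space does:

```python
from transformers import GenerationConfig

# Equivalent sampling setup to the generate() call above, expressed as a config object.
gen_config = GenerationConfig(
    max_new_tokens=300,   # jc_tokens slider
    do_sample=True,
    top_p=0.9,            # jc_topp slider
    temperature=0.6,      # jc_temperature slider
)
# generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds,
#                                    attention_mask=attention_mask,
#                                    generation_config=gen_config)
```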