Spaces: Running on Zero
Upload joycaption.py
Browse files - joycaption.py +4 -4
joycaption.py
CHANGED
@@ -30,11 +30,11 @@ BASE_DIR = Path(__file__).resolve().parent
|
|
30 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
31 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
32 |
use_inference_client = False
|
33 |
-
|
34 |
|
35 |
llm_models = {
|
36 |
"Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2": None,
|
37 |
-
#
|
38 |
"bunnycore/LLama-3.1-8B-Matrix": None,
|
39 |
"Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
|
40 |
"unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
|
@@ -150,7 +150,7 @@ def load_text_model(model_name: str=MODEL_PATH, gguf_file: Union[str, None]=None
|
|
150 |
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
|
151 |
bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
|
152 |
|
153 |
-
if model_name
|
154 |
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
155 |
if is_nf4:
|
156 |
text_model = LlavaForConditionalGeneration.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
|
@@ -249,7 +249,7 @@ def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: s
|
|
249 |
print(f"Prompt: {prompt_str}")
|
250 |
|
251 |
# Pixtral
|
252 |
-
if model_name
|
253 |
input_images = [input_image]
|
254 |
inputs = image_adapter(text=prompt_str, images=input_images, return_tensors="pt").to(device)
|
255 |
generate_ids = text_model.generate(**inputs, max_new_tokens=max_new_tokens)
|
|
|
30 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
31 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
32 |
use_inference_client = False
|
33 |
+
PIXTRAL_PATHS = ["mistral-community/pixtral-12b"]
|
34 |
|
35 |
llm_models = {
|
36 |
"Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2": None,
|
37 |
+
#PIXTRAL_PATHS[0]: None,
|
38 |
"bunnycore/LLama-3.1-8B-Matrix": None,
|
39 |
"Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
|
40 |
"unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
|
|
|
150 |
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
|
151 |
bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
|
152 |
|
153 |
+
if model_name in PIXTRAL_PATHS:
|
154 |
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
155 |
if is_nf4:
|
156 |
text_model = LlavaForConditionalGeneration.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
|
|
|
249 |
print(f"Prompt: {prompt_str}")
|
250 |
|
251 |
# Pixtral
|
252 |
+
if model_name in PIXTRAL_PATHS:
|
253 |
input_images = [input_image]
|
254 |
inputs = image_adapter(text=prompt_str, images=input_images, return_tensors="pt").to(device)
|
255 |
generate_ids = text_model.generate(**inputs, max_new_tokens=max_new_tokens)
|