John6666 committed
Commit c796bf8 · verified · Parent: bfd1b89

Upload joycaption.py

Files changed (1): joycaption.py (+4 -4)
joycaption.py CHANGED

@@ -30,11 +30,11 @@ BASE_DIR = Path(__file__).resolve().parent
 device = "cuda" if torch.cuda.is_available() else "cpu"
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 use_inference_client = False
-PIXTRAL_PATH = "mistral-community/pixtral-12b"
+PIXTRAL_PATHS = ["mistral-community/pixtral-12b"]
 
 llm_models = {
     "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2": None,
-    #PIXTRAL_PATH: None,
+    #PIXTRAL_PATHS[0]: None,
     "bunnycore/LLama-3.1-8B-Matrix": None,
     "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
     "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
@@ -150,7 +150,7 @@ def load_text_model(model_name: str=MODEL_PATH, gguf_file: Union[str, None]=None
     nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                     bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
 
-    if model_name == PIXTRAL_PATH:
+    if model_name in PIXTRAL_PATHS:
         from transformers import AutoProcessor, LlavaForConditionalGeneration
         if is_nf4:
             text_model = LlavaForConditionalGeneration.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
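Note: the branch this condition guards loads Pixtral through transformers' LLaVA classes. A self-contained sketch of that NF4 load path under the same settings as the hunk above (assumes torch, transformers, and bitsandbytes are installed; an illustration, not the full load_text_model):

import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "mistral-community/pixtral-12b"

# 4-bit NF4 double quantization with bfloat16 compute, as in the hunk above.
nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                bnb_4bit_use_double_quant=True,
                                bnb_4bit_compute_dtype=torch.bfloat16)

processor = AutoProcessor.from_pretrained(model_name)  # plays the role of image_adapter
text_model = LlavaForConditionalGeneration.from_pretrained(
    model_name, quantization_config=nf4_config,
    device_map=device, torch_dtype=torch.bfloat16).eval()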
@@ -249,7 +249,7 @@ def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: s
     print(f"Prompt: {prompt_str}")
 
     # Pixtral
-    if model_name == PIXTRAL_PATH:
+    if model_name in PIXTRAL_PATHS:
         input_images = [input_image]
         inputs = image_adapter(text=prompt_str, images=input_images, return_tensors="pt").to(device)
         generate_ids = text_model.generate(**inputs, max_new_tokens=max_new_tokens)
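Note: the diff is truncated after generate_ids. A hedged continuation sketch, reusing processor, text_model, and device from the loading sketch above, of how such ids are typically decoded into a caption (the prompt and image path are hypothetical, and the prompt-stripping slice is a common convention, not taken from this file):

from PIL import Image

input_image = Image.open("example.jpg")  # hypothetical input
prompt_str = "Describe this image."      # hypothetical prompt
inputs = processor(text=prompt_str, images=[input_image], return_tensors="pt").to(device)
generate_ids = text_model.generate(**inputs, max_new_tokens=300)
# Drop the echoed prompt tokens before decoding (assumed convention).
new_tokens = generate_ids[:, inputs["input_ids"].shape[1]:]
caption = processor.batch_decode(new_tokens, skip_special_tokens=True)[0]
print(caption)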
 