John6666 committed on
Commit
65832a2
·
verified ·
1 Parent(s): f902fc6

Upload joycaption.py

Browse files
Files changed (1) hide show
  1. joycaption.py +4 -6
joycaption.py CHANGED
@@ -11,12 +11,15 @@ import os
11
  import gc
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
14
 
15
  llm_models = {
16
  "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
17
  "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
 
18
  "mergekit-community/L3.1-Boshima-b-FIX": None,
19
- "meta-llama/Meta-Llama-3.1-8B": None,
20
  }
21
 
22
  CLIP_PATH = "google/siglip-so400m-patch14-384"
@@ -25,9 +28,6 @@ MODEL_PATH = list(llm_models.keys())[0]
25
  CHECKPOINT_PATH = Path("wpkklhc6")
26
  TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
27
 
28
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
29
- use_inference_client = False
30
-
31
  class ImageAdapter(nn.Module):
32
  def __init__(self, input_features: int, output_features: int):
33
  super().__init__()
@@ -200,8 +200,6 @@ def stream_chat_mod(input_image: Image.Image, max_new_tokens: int=300, top_k: in
200
  #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
201
  generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
202
  max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, temperature=temperature, suppress_tokens=None)
203
-
204
- print(prompt)
205
 
206
  # Trim off the prompt
207
  generate_ids = generate_ids[:, input_ids.shape[1]:]
 
11
  import gc
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
15
+ use_inference_client = False
16
 
17
  llm_models = {
18
  "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
19
  "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
20
+ "DevQuasar/HermesNova-Llama-3.1-8B": None,
21
  "mergekit-community/L3.1-Boshima-b-FIX": None,
22
+ "meta-llama/Meta-Llama-3.1-8B": None, # gated
23
  }
24
 
25
  CLIP_PATH = "google/siglip-so400m-patch14-384"
 
28
  CHECKPOINT_PATH = Path("wpkklhc6")
29
  TITLE = "<h1><center>JoyCaption Pre-Alpha (2024-07-30a)</center></h1>"
30
 
 
 
 
31
  class ImageAdapter(nn.Module):
32
  def __init__(self, input_features: int, output_features: int):
33
  super().__init__()
 
200
  #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
201
  generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
202
  max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, temperature=temperature, suppress_tokens=None)
 
 
203
 
204
  # Trim off the prompt
205
  generate_ids = generate_ids[:, input_ids.shape[1]:]