John6666 committed (verified)
Commit cd64e85 · 1 Parent(s): 8e6bf35

Upload 2 files

Files changed (2)
  1. app.py +1 -1
  2. joycaption.py +29 -24
app.py CHANGED
@@ -4,7 +4,7 @@ from joycaption import stream_chat_mod, get_text_model, change_text_model, get_r
 
 JC_TITLE_MD = "<h1><center>JoyCaption Alpha One Mod</center></h1>"
 JC_DESC_MD = """This space is mod of [fancyfeast/joy-caption-alpha-one](https://huggingface.co/spaces/fancyfeast/joy-caption-alpha-one),
-[Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha)"""
+[Wi-zz/joy-caption-pre-alpha](https://huggingface.co/Wi-zz/joy-caption-pre-alpha). Thanks to [dominic1021](https://huggingface.co/dominic1021)"""
 
 css = """
 .info {text-align:center; !important}
joycaption.py CHANGED
@@ -19,10 +19,14 @@ from PIL import Image
 import torchvision.transforms.functional as TVF
 import gc
 from peft import PeftConfig
+from typing import Union
 
 import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
+# Define the base directory
+BASE_DIR = Path(__file__).resolve().parent
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 use_inference_client = False
@@ -38,7 +42,7 @@ llm_models = {
 
 CLIP_PATH = "google/siglip-so400m-patch14-384"
 MODEL_PATH = list(llm_models.keys())[0]
-CHECKPOINT_PATH = Path("9em124t2-499968")
+CHECKPOINT_PATH = BASE_DIR / Path("9em124t2-499968")
 LORA_PATH = CHECKPOINT_PATH / "text_model"
 TITLE = "<h1><center>JoyCaption Alpha One (2024-09-20a)</center></h1>"
 CAPTION_TYPE_MAP = {
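The two hunks above anchor the checkpoint directory to the script's own location instead of the current working directory: `BASE_DIR = Path(__file__).resolve().parent` and `CHECKPOINT_PATH = BASE_DIR / Path("9em124t2-499968")`. A minimal sketch of the difference (only the checkpoint folder name comes from the diff; the rest is illustrative):

```python
from pathlib import Path

# Anchor paths at the directory containing this file, as the commit does,
# so the checkpoint resolves the same way no matter where the app is launched from.
BASE_DIR = Path(__file__).resolve().parent

# Before: relative to the current working directory (breaks if the Space
# process is started from a different directory).
cwd_relative = Path("9em124t2-499968")

# After: always next to joycaption.py.
anchored = BASE_DIR / "9em124t2-499968"

print(cwd_relative.resolve())  # depends on os.getcwd()
print(anchored)                # depends only on the script's location
```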
@@ -137,36 +141,41 @@ text_model_client = None
 text_model = None
 image_adapter = None
 peft_config = None
-def load_text_model(model_name: str=MODEL_PATH, gguf_file: str | None=None, is_nf4: bool=True):
-    global tokenizer
-    global text_model
-    global image_adapter
-    global peft_config
-    global text_model_client #
-    global use_inference_client #
+def load_text_model(model_name: str=MODEL_PATH, gguf_file: Union[str, None]=None, is_nf4: bool=True):
+    global tokenizer, text_model, image_adapter, peft_config, text_model_client, use_inference_client
     try:
         from transformers import BitsAndBytesConfig
         nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                         bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
+
         print("Loading tokenizer")
         if gguf_file: tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
         else: tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, legacy=False)
         assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
+
         print(f"Loading LLM: {model_name}")
         if gguf_file:
-            if device == "cpu": text_model = AutoModelForCausalLM.from_pretrained(model_name, gguf_file=gguf_file, device_map=device, torch_dtype=torch.bfloat16).eval()
-            elif is_nf4: text_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
-            else: text_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
+            if device == "cpu":
+                text_model = AutoModelForCausalLM.from_pretrained(model_name, gguf_file=gguf_file, device_map=device, torch_dtype=torch.bfloat16).eval()
+            elif is_nf4:
+                text_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
+            else:
+                text_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
         else:
-            if device == "cpu": text_model = AutoModelForCausalLM.from_pretrained(model_name, gguf_file=gguf_file, device_map=device, torch_dtype=torch.bfloat16).eval()
-            elif is_nf4: text_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
-            else: text_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
+            if device == "cpu":
+                text_model = AutoModelForCausalLM.from_pretrained(model_name, gguf_file=gguf_file, device_map=device, torch_dtype=torch.bfloat16).eval()
+            elif is_nf4:
+                text_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config, device_map=device, torch_dtype=torch.bfloat16).eval()
+            else:
+                text_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
+
         if LORA_PATH.exists():
             print("Loading VLM's custom text model")
             if is_nf4: peft_config = PeftConfig.from_pretrained(LORA_PATH, device_map=device, quantization_config=nf4_config)
             else: peft_config = PeftConfig.from_pretrained(LORA_PATH, device_map=device)
             text_model.add_adapter(peft_config)
             text_model.enable_adapters()
+
         print("Loading image adapter")
         image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False).eval().to("cpu")
         image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=True))
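For reference, a stripped-down sketch of the NF4 branch that load_text_model takes on GPU; the model name is a placeholder (the Space picks one from llm_models), and the real function additionally handles GGUF files, a CPU fallback, and the LoRA adapter:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder model id

# 4-bit NF4 quantization with double quantization and bfloat16 compute,
# mirroring the nf4_config built inside load_text_model.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False, legacy=False)
text_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=nf4_config,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
).eval()
```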
@@ -186,7 +195,7 @@ clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
 clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model
 if (CHECKPOINT_PATH / "clip_model.pt").exists():
     print("Loading VLM's custom vision model")
-    checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
+    checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu', weights_only=True)
     checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
     clip_model.load_state_dict(checkpoint)
     del checkpoint
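The only change in this hunk adds weights_only=True to torch.load, matching how the image adapter is loaded above; it restricts unpickling to tensors and plain containers rather than arbitrary Python objects. A small illustration (the file name here is hypothetical and only stands in for clip_model.pt):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), "clip_model_demo.pt")  # hypothetical demo checkpoint

# weights_only=True tells torch.load to use a restricted unpickler that only
# accepts tensors and basic containers, so a tampered checkpoint cannot run
# arbitrary code while being loaded.
checkpoint = torch.load("clip_model_demo.pt", map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint)
```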
@@ -197,10 +206,9 @@ clip_model.eval().requires_grad_(False).to(device)
 # Image Adapter
 load_text_model()
 
-
 @spaces.GPU()
 @torch.no_grad()
-def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int) -> str:
+def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int]) -> str:
     torch.cuda.empty_cache()
 
     # 'any' means no length specified
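This hunk and the remaining ones swap the `X | Y` annotations for `typing.Union`, which is presumably why `from typing import Union` was added at the top: `str | int` in annotations needs Python 3.10+ (PEP 604), while `Union[str, int]` also runs on 3.9 and earlier. A quick sketch, assuming version compatibility is the motivation; the helper below is hypothetical and only exists to show the annotation:

```python
from typing import Union

# Works on Python 3.9+: Union[str, int] describes the same type as `str | int`.
def caption_length_to_words(caption_length: Union[str, int]) -> str:
    if isinstance(caption_length, int) or str(caption_length).isdigit():
        return f"within {caption_length} words"
    return caption_length  # e.g. "short", "long", "any"

print(caption_length_to_words(80))
print(caption_length_to_words("short"))
```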
@@ -276,12 +284,10 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str,
 
     return caption.strip()
 
-
 @spaces.GPU()
 @torch.no_grad()
-def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, progress=gr.Progress(track_tqdm=True)) -> str:
-    global use_inference_client
-    global text_model
+def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int], max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, progress=gr.Progress(track_tqdm=True)) -> str:
+    global use_inference_client, text_model
     torch.cuda.empty_cache()
     gc.collect()
 
@@ -437,10 +443,9 @@ def get_repo_gguf(repo_id: str):
 
 
 @spaces.GPU()
-def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_file: str | None=None,
+def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_file: Union[str, None]=None,
                       is_nf4: bool=True, progress=gr.Progress(track_tqdm=True)):
-    global use_inference_client
-    global llm_models
+    global use_inference_client, llm_models
     use_inference_client = use_client
     try:
         if not is_repo_name(model_name) or not is_repo_exists(model_name):
 