switch to qwen2.5 vl

#2
by davanstrien
Files changed (1)
  1. app.py +68 -39
app.py CHANGED
@@ -1,14 +1,15 @@
-# import subprocess # 🥲
-
-# subprocess.run(
-#     "pip install flash-attn --no-build-isolation",
-#     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-#     shell=True,
-# )
+import subprocess # 🥲
 
+subprocess.run(
+    "pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
 import spaces
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
 import torch
 import os
 import json
@@ -17,19 +18,15 @@ from typing import Tuple
 
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
-# Load Molmo model
-model = AutoModelForCausalLM.from_pretrained(
-    'allenai/Molmo-7B-D-0924',
-    trust_remote_code=True,
-    torch_dtype='auto',
-    device_map='auto'
-)
-processor = AutoProcessor.from_pretrained(
-    'allenai/Molmo-7B-D-0924',
-    trust_remote_code=True,
-    torch_dtype='auto',
-    device_map='auto'
+
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    "Qwen/Qwen2.5-VL-7B-Instruct",
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+    device_map="auto",
 )
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
 
 class GeneralRetrievalQuery(BaseModel):
     broad_topical_query: str
@@ -39,6 +36,7 @@ class GeneralRetrievalQuery(BaseModel):
     visual_element_query: str
     visual_element_explanation: str
 
+
 def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
     if prompt_name != "general":
         raise ValueError("Only 'general' prompt is available in this version")
@@ -72,46 +70,77 @@ Format your response as a JSON object with the following structure:
 If there are no relevant visual elements, replace the third query with another specific detail query.
 
 Here is the document image to analyze:
+<image>
 
-Generate the queries based on this image and provide the response in the specified JSON format.
-Only return JSON. Don't return any extra explanation text. """
+Generate the queries based on this image and provide the response in the specified JSON format."""
 
     return prompt, GeneralRetrievalQuery
 
+
+# defined like this so we can later add more prompting options
 prompt, pydantic_model = get_retrieval_prompt("general")
 
+
 def _prep_data_for_input(image):
-    return processor.process(
-        images=[image],
-        text=prompt
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": image,
+                },
+                {"type": "text", "text": prompt},
+            ],
+        }
+    ]
+
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    image_inputs, video_inputs = process_vision_info(messages)
+
+    return processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
     )
 
-@spaces.GPU(duration=120)
+
+@spaces.GPU
 def generate_response(image):
     inputs = _prep_data_for_input(image)
-    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
-    output = model.generate_from_batch(
-        inputs,
-        GenerationConfig(max_new_tokens=800, stop_token="<|endoftext|>"),
-        tokenizer=processor.tokenizer
+    inputs = inputs.to("cuda")
+
+    generated_ids = model.generate(**inputs, max_new_tokens=200)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids) :]
+        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+
+    output_text = processor.batch_decode(
+        generated_ids_trimmed,
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=False,
     )
-    generated_tokens = output[0, inputs['input_ids'].size(1):]
-    output_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
-
     try:
-        return str(json.loads(output_text))
+        return json.loads(output_text[0])
     except Exception:
         gr.Warning("Failed to parse JSON from output")
-        return output_text
+        return {}
+
 
-title = "ColPali fine-tuning Query Generator"
+title = "ColPali Query Generator using Qwen2.5-VL"
 description = """[ColPali](https://huggingface.co/papers/2407.01449) is a very exciting new approach to multimodal document retrieval which aims to replace existing document retrievers which often rely on an OCR step with an end-to-end multimodal approach.
 
 To train or fine-tune a ColPali model, we need a dataset of image-text pairs which represent the document images and the relevant text queries which those documents should match.
 To make the ColPali models work even better we might want a dataset of query/image document pairs related to our domain or task.
 
 One way in which we might go about generating such a dataset is to use a VLM to generate synthetic queries for us.
-This space uses the [allenai/Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) model to generate queries for a document, based on an input document image.
+This space uses the [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) VLM model to generate queries for a document, based on an input document image.
 
 **Note** there is a lot of scope for improving to prompts and the quality of the generated queries! If you have any suggestions for improvements please [open a Discussion](https://huggingface.co/spaces/davanstrien/ColPali-Query-Generator/discussions/new)!
 
@@ -128,7 +157,7 @@ examples = [
 demo = gr.Interface(
     fn=generate_response,
     inputs=gr.Image(type="pil"),
-    outputs=gr.Text(),
+    outputs=gr.Json(),
     title=title,
     description=description,
     examples=examples,
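
For reference, the Qwen2.5-VL generation path introduced by this change can be exercised outside the Gradio Space roughly as follows. This is a minimal sketch, not part of the PR: it assumes a transformers release with Qwen2.5-VL support plus the qwen_vl_utils package, it omits the flash_attention_2 setting so it runs without flash-attn, and the image path and short prompt are placeholders for a real document image and the full retrieval prompt built in app.py.

# Minimal sketch (not part of the PR): the new Qwen2.5-VL query-generation flow,
# run standalone. Assumes transformers with Qwen2.5-VL support and qwen_vl_utils;
# flash_attention_2 is omitted so this runs without flash-attn. "document.png"
# and the short prompt are placeholders for the real inputs used in app.py.
import json

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

image = Image.open("document.png")  # placeholder document image
prompt = "Generate three retrieval queries for this document. Respond only with JSON."  # placeholder prompt

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }
]

# Build the chat-formatted prompt and the vision inputs, then tokenize them together.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(model.device)

# Generate, drop the prompt tokens, and decode only the newly generated tokens.
generated_ids = model.generate(**inputs, max_new_tokens=200)
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
output_text = processor.batch_decode(trimmed, skip_special_tokens=True)[0]

try:
    print(json.loads(output_text))
except json.JSONDecodeError:
    print(output_text)  # fall back to the raw text if the model did not return valid JSON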