File size: 5,783 Bytes
9e7d682
ec8173c
9e7d682
 
 
 
 
3f53d8e
4e1ec1c
4eac50b
3555196
4e1ec1c
 
 
1796549
 
4eac50b
4e1ec1c
1796549
ec8173c
 
 
 
 
 
 
 
 
 
 
 
4e1ec1c
1796549
4e1ec1c
 
 
 
 
 
 
 
be1e49c
 
 
 
 
4e1ec1c
 
 
 
 
ff86a3f
4e1ec1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e7d682
efd5a7f
4e1ec1c
be1e49c
4e1ec1c
 
 
be1e49c
ec8173c
 
 
4e1ec1c
 
38576ff
be1e49c
 
ec8173c
 
 
9e7d682
ec8173c
4e1ec1c
ec8173c
 
 
4e1ec1c
38576ff
4e1ec1c
7b6a165
38576ff
4e1ec1c
e3dcfdd
 
6c3e99a
 
e3dcfdd
6c3e99a
ec8173c
 
e3dcfdd
b42e8c6
10a9ffa
 
e3dcfdd
 
10a9ffa
7b6a165
1d82c63
be1e49c
 
1d82c63
 
7b6a165
 
be1e49c
38576ff
7b6a165
 
1d82c63
7b6a165
ec8173c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# import subprocess  # 🥲

# subprocess.run(
#     "pip install flash-attn --no-build-isolation",
#     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
#     shell=True,
# )

import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
import torch
import os
import json
from pydantic import BaseModel
from typing import Tuple
from typing import Type

# NOTE(review): huggingface_hub appears to read HF_HUB_ENABLE_HF_TRANSFER when it
# is first imported (here indirectly, via the `transformers` import above), so
# setting it at this point may not affect downloads. Confirm; if faster downloads
# are wanted, set it before the imports (it also needs the `hf_transfer` package).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Single source of truth for the checkpoint, so the model and its processor
# cannot drift apart when the id is changed.
MODEL_ID = "allenai/Molmo-7B-D-0924"

# Load Molmo model. trust_remote_code=True lets the checkpoint repository's own
# modelling/processing code run; device_map/torch_dtype "auto" let transformers
# pick placement and precision.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
)

class GeneralRetrievalQuery(BaseModel):
    """Schema mirroring the JSON object the "general" retrieval prompt requests.

    Each of the three query types is paired with a short explanation of why
    that query would retrieve the document page.
    """

    # Query covering the main subject of the document.
    broad_topical_query: str
    broad_topical_explanation: str
    # Query focused on a particular fact, figure, or point in the document.
    specific_detail_query: str
    specific_detail_explanation: str
    # Query that a chart/graph/image on the page could help answer; the prompt
    # asks for another specific-detail query when no visual element is present.
    visual_element_query: str
    visual_element_explanation: str

def get_retrieval_prompt(prompt_name: str) -> Tuple[str, Type[GeneralRetrievalQuery]]:
    """Return the instruction text and response schema for a named prompt.

    Args:
        prompt_name: Name of the prompt template. Only ``"general"`` is
            supported in this version.

    Returns:
        A ``(prompt, schema)`` tuple: the instruction text sent to the VLM, and
        the pydantic model *class* (not an instance) whose fields match the
        JSON keys the prompt asks the model to return.

    Raises:
        ValueError: If ``prompt_name`` is anything other than ``"general"``.
    """
    if prompt_name != "general":
        raise ValueError("Only 'general' prompt is available in this version")

    prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.

Please generate 3 different types of retrieval queries:

1. A broad topical query: This should cover the main subject of the document.
2. A specific detail query: This should focus on a particular fact, figure, or point made in the document.
3. A visual element query: This should reference a chart, graph, image, or other visual component in the document, if present. Don't just reference the name of the visual element but generate a query which this illustration may help answer or be related to.

Important guidelines:
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
- Frame the queries as if someone is searching for this document, not asking questions about its content.
- Make the queries diverse and representative of different search strategies.

For each query, also provide a brief explanation of why this query would be effective in retrieving this document.

Format your response as a JSON object with the following structure:

{
  "broad_topical_query": "Your query here",
  "broad_topical_explanation": "Brief explanation",
  "specific_detail_query": "Your query here",
  "specific_detail_explanation": "Brief explanation",
  "visual_element_query": "Your query here",
  "visual_element_explanation": "Brief explanation"
}

If there are no relevant visual elements, replace the third query with another specific detail query.

Here is the document image to analyze:

Generate the queries based on this image and provide the response in the specified JSON format.
Only return JSON. Don't return any extra explanation text. """

    # The schema class itself is returned (hence Type[...] in the annotation).
    return prompt, GeneralRetrievalQuery

# Resolve the only supported prompt once at import time; `pydantic_model` is the
# schema class returned alongside the prompt text.
prompt, pydantic_model = get_retrieval_prompt(prompt_name="general")

def _prep_data_for_input(image):
    """Run the Molmo processor on a single image paired with the module prompt.

    Returns whatever ``processor.process`` produces for a one-image request;
    the caller iterates it with ``.items()`` and moves each value to a device.
    """
    single_image_batch = [image]
    return processor.process(text=prompt, images=single_image_batch)

def _parse_json_object(text):
    """Decode the JSON value in a model response, tolerating surrounding text.

    Tries the whitespace-stripped text first; if that fails, retries on the span
    from the first ``{`` to the last ``}``, which recovers answers wrapped in
    Markdown code fences or preceded by a short preamble.

    Raises:
        json.JSONDecodeError: If no JSON object can be decoded.
    """
    candidate = text.strip()
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        start, end = candidate.find("{"), candidate.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(candidate[start:end + 1])


@spaces.GPU(duration=120)
def generate_response(image):
    """Generate synthetic retrieval queries for one document page image.

    Args:
        image: The page image from the Gradio input (a PIL image, since the
            input component uses ``type="pil"``); may be ``None`` if the user
            submits without uploading an image.

    Returns:
        The model's answer pretty-printed as JSON when it can be parsed;
        otherwise the raw decoded text, with a Gradio warning shown.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload a document page image.")

    inputs = _prep_data_for_input(image)
    # Move every processor output to the model's device and add a leading
    # batch dimension, since generate_from_batch consumes a batch.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
    output = model.generate_from_batch(
        inputs,
        # `stop_strings` is the GenerationConfig option for stopping on a
        # literal string (`stop_token` is not one, so it was ignored); it needs
        # the tokenizer, which is passed below.
        GenerationConfig(max_new_tokens=800, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )
    # Keep only the newly generated tokens, dropping the echoed prompt tokens.
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    output_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    try:
        parsed = _parse_json_object(output_text)
    except json.JSONDecodeError:
        gr.Warning("Failed to parse JSON from output")
        return output_text
    # Return real JSON (double quotes, true/false/null) rather than str(dict),
    # which produced a Python literal that is not valid JSON.
    return json.dumps(parsed, indent=2, ensure_ascii=False)

# Header text for the Gradio interface; the description is rendered as Markdown.
title = "ColPali fine-tuning Query Generator"
description = """[ColPali](https://huggingface.co/papers/2407.01449) is a very exciting new approach to multimodal document retrieval which aims to replace existing document retrievers which often rely on an OCR step with an end-to-end multimodal approach. 

To train or fine-tune a ColPali model, we need a dataset of image-text pairs which represent the document images and the relevant text queries which those documents should match. 
To make the ColPali models work even better we might want a dataset of query/image document pairs related to our domain or task. 

One way in which we might go about generating such a dataset is to use a VLM to generate synthetic queries for us. 
This space uses the [allenai/Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924) model to generate queries for a document, based on an input document image. 

**Note** there is a lot of scope for improving the prompts and the quality of the generated queries! If you have any suggestions for improvements, please [open a Discussion](https://huggingface.co/spaces/davanstrien/ColPali-Query-Generator/discussions/new)!

This [blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) gives an overview of how you can use this kind of approach to generate a full dataset for fine-tuning ColPali models. 

If you want to convert one or more PDFs to a dataset of page images, you can try out the [PDFs to Page Images Converter](https://huggingface.co/spaces/Dataset-Creation-Tools/pdf-to-page-images-dataset) Space.
"""

# Example page images offered in the UI (paths relative to the app's working directory).
examples = [
    "examples/Approche_no_13_1977.pdf_page_22.jpg",
    "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
]

# One image in (as a PIL image, matching what generate_response expects),
# one text box out.
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(),
    title=title,
    description=description,
    examples=examples,
)
demo.launch()