figh8back committed on
Commit
dbacd88
1 Parent(s): 5ea229c

Update app.py

Files changed (1): app.py +176 -0
app.py CHANGED
@@ -0,0 +1,176 @@
+ import subprocess
+ # Install FlashAttention at startup; the env var skips compiling the CUDA kernels
+ # and pulls a prebuilt wheel instead.
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoProcessor, Idefics2ForConditionalGeneration
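+
+ # Load the IDEFICS2-8B processor and model once at startup (a CUDA GPU is assumed).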
+ processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
+
+ model = Idefics2ForConditionalGeneration.from_pretrained(
+     "HuggingFaceM4/idefics2-8b",
+     torch_dtype=torch.bfloat16,
+     # _attn_implementation="flash_attention_2",
+     trust_remote_code=True,
+ ).to("cuda")
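+
+ # On ZeroGPU Spaces, `spaces.GPU` allocates a GPU for the decorated call (up to 180 s here).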
+ @spaces.GPU(duration=180)
+ def model_inference(
+     image, text, decoding_strategy, temperature,
+     max_new_tokens, repetition_penalty, top_p
+ ):
+     if text == "" and not image:
+         raise gr.Error("Please input a query and optionally image(s).")
+
+     if text == "" and image:
+         raise gr.Error("Please input a text query along with the image(s).")
+
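+     # Build a single-turn chat message: an image placeholder followed by the text query.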
+     resulting_messages = [
+         {
+             "role": "user",
+             "content": [{"type": "image"}] + [
+                 {"type": "text", "text": text}
+             ]
+         }
+     ]
+
+     prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
+     inputs = processor(text=prompt, images=[image], return_tensors="pt")
+     inputs = {k: v.to("cuda") for k, v in inputs.items()}
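+
+     # Base generation settings shared by both decoding strategies.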
+     generation_args = {
+         "max_new_tokens": max_new_tokens,
+         "repetition_penalty": repetition_penalty,
+     }
+
+     assert decoding_strategy in [
+         "Greedy",
+         "Top P Sampling",
+     ]
+     if decoding_strategy == "Greedy":
+         generation_args["do_sample"] = False
+     elif decoding_strategy == "Top P Sampling":
+         generation_args["temperature"] = temperature
+         generation_args["do_sample"] = True
+         generation_args["top_p"] = top_p
+
+     generation_args.update(inputs)
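+
+     # Generate, then decode only the newly produced tokens (slice off the prompt).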
+     generated_ids = model.generate(**generation_args)
+     generated_texts = processor.batch_decode(
+         generated_ids[:, generation_args["input_ids"].size(1):],
+         skip_special_tokens=True,
+     )
+     print("INPUT:", prompt, "|OUTPUT:", generated_texts)
+     return generated_texts[0]
+
+
+ with gr.Blocks(fill_height=True) as demo:
+     gr.Markdown("## IDEFICS2 Instruction 🐶")
+     gr.Markdown("Play with [IDEFICS2-8B](https://huggingface.co/HuggingFaceM4/idefics2-8b) in this demo. To get started, upload an image and a text prompt, or try one of the examples.")
+     gr.Markdown("**Important note**: This model is not made for chatting; the chatty IDEFICS2 will be released in the coming days. **It is very strong at various tasks, including visual question answering, document retrieval and more, as you can see in the examples.**")
+     gr.Markdown("Learn more about IDEFICS2 in this [blog post](https://huggingface.co/blog/idefics2).")
+
+     with gr.Column():
+         image_input = gr.Image(label="Upload your Image", type="pil")
+         query_input = gr.Textbox(label="Prompt")
+         submit_btn = gr.Button("Submit")
+         output = gr.Textbox(label="Output")
+
+         with gr.Accordion(label="Example Inputs and Advanced Generation Parameters"):
+             examples = [
+                 ["./example_images/docvqa_example.png", "How many items are sold?", "Greedy", 0.4, 512, 1.2, 0.8],
+                 ["./example_images/example_images_travel_tips.jpg", "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips.", "Greedy", 0.4, 512, 1.2, 0.8],
+                 ["./example_images/baklava.png", "Where is this pastry from?", "Greedy", 0.4, 512, 1.2, 0.8],
+                 ["./example_images/dummy_pdf.png", "What percentage is the order status?", "Greedy", 0.4, 512, 1.2, 0.8],
+                 ["./example_images/art_critic.png", "As an art critic AI assistant, could you describe this painting in detail and write a thorough critique?", "Greedy", 0.4, 512, 1.2, 0.8],
+                 ["./example_images/s2w_example.png", "What is this UI about?", "Greedy", 0.4, 512, 1.2, 0.8],
+             ]
+
+             # Hyper-parameters for generation
+             max_new_tokens = gr.Slider(
+                 minimum=8,
+                 maximum=1024,
+                 value=512,
+                 step=1,
+                 interactive=True,
+                 label="Maximum number of new tokens to generate",
+             )
+             repetition_penalty = gr.Slider(
+                 minimum=0.01,
+                 maximum=5.0,
+                 value=1.2,
+                 step=0.01,
+                 interactive=True,
+                 label="Repetition penalty",
+                 info="1.0 is equivalent to no penalty",
+             )
+             temperature = gr.Slider(
+                 minimum=0.0,
+                 maximum=5.0,
+                 value=0.4,
+                 step=0.1,
+                 interactive=True,
+                 label="Sampling temperature",
+                 info="Higher values will produce more diverse outputs.",
+             )
+             top_p = gr.Slider(
+                 minimum=0.01,
+                 maximum=0.99,
+                 value=0.8,
+                 step=0.01,
+                 interactive=True,
+                 label="Top P",
+                 info="Higher values sample more low-probability tokens.",
+             )
+             decoding_strategy = gr.Radio(
+                 [
+                     "Greedy",
+                     "Top P Sampling",
+                 ],
+                 value="Greedy",
+                 label="Decoding strategy",
+                 interactive=True,
+                 info="Greedy always picks the most likely next token; Top P samples from the most likely tokens whose probabilities add up to top_p.",
+             )
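+             # Show the temperature, repetition-penalty and top-p sliders only when "Top P Sampling" is selected.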
+             decoding_strategy.change(
+                 fn=lambda selection: gr.Slider(visible=(selection == "Top P Sampling")),
+                 inputs=decoding_strategy,
+                 outputs=temperature,
+             )
+             decoding_strategy.change(
+                 fn=lambda selection: gr.Slider(visible=(selection == "Top P Sampling")),
+                 inputs=decoding_strategy,
+                 outputs=repetition_penalty,
+             )
+             decoding_strategy.change(
+                 fn=lambda selection: gr.Slider(visible=(selection == "Top P Sampling")),
+                 inputs=decoding_strategy,
+                 outputs=top_p,
+             )
+
+             gr.Examples(
+                 examples=examples,
+                 inputs=[image_input, query_input, decoding_strategy, temperature,
+                         max_new_tokens, repetition_penalty, top_p],
+                 outputs=output,
+                 fn=model_inference,
+             )
+
+     submit_btn.click(
+         model_inference,
+         inputs=[image_input, query_input, decoding_strategy, temperature,
+                 max_new_tokens, repetition_penalty, top_p],
+         outputs=output,
+     )
+
+ demo.launch(debug=True)