mattraj committed on
Commit
9228783
1 Parent(s): ecdedbd

demo buildout

.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/curacel-demo-1.iml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="inheritedJdk" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
+   </profile>
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/curacel-demo-1.iml" filepath="$PROJECT_DIR$/.idea/curacel-demo-1.iml" />
+     </modules>
+   </component>
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="" vcs="Git" />
+   </component>
+ </project>
app.py CHANGED
@@ -1,63 +1,336 @@
  import gradio as gr
- from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]

-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})

-     messages.append({"role": "user", "content": message})

-     response = ""

-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content

-         response += token
-         yield response

  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )


  if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
+ import PIL.Image
+ import transformers
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
+ import torch
+ import os
+ import string
+ import functools
+ import re
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import spaces

+ model_id = "mattraj/curacel-transcription-1"
+ COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
+ processor = PaliGemmaProcessor.from_pretrained(model_id)

+ def resize_and_pad(image, target_dim):
+     # Calculate the aspect ratio
+     scale_factor = 1
+     aspect_ratio = image.width / image.height
+     if aspect_ratio > 1:
+         # Width is greater than height
+         new_width = int(target_dim * scale_factor)
+         new_height = int((target_dim / aspect_ratio) * scale_factor)
+     else:
+         # Height is greater than width
+         new_height = int(target_dim * scale_factor)
+         new_width = int(target_dim * aspect_ratio * scale_factor)

+     # Only PIL.Image is imported above, so the PIL names must be fully qualified
+     resized_image = image.resize((new_width, new_height), PIL.Image.LANCZOS)

+     # Create a new image with the target dimensions and a white background
+     new_image = PIL.Image.new("RGB", (target_dim, target_dim), (255, 255, 255))
+     new_image.paste(resized_image, ((target_dim - new_width) // 2, (target_dim - new_height) // 2))

+     return new_image
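
+ # Worked example (illustrative values, not from the commit): a 1000x500 input
+ # with target_dim=448 has aspect_ratio 2.0, so it is resized to 448x224 and
+ # pasted at offset (0, 112) on the white 448x448 canvas.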


+ ###### Transformers Inference
+ @spaces.GPU
+ def infer(
+     image: PIL.Image.Image,
+     text: str,
+     max_new_tokens: int
+ ) -> str:
+     # resize_and_pad requires a target size; 448 matches the
+     # google/paligemma-3b-pt-448 base checkpoint named in INTRO_TEXT
+     inputs = processor(text=text, images=resize_and_pad(image, 448), return_tensors="pt").to(device)
+     with torch.inference_mode():
+         generated_ids = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=False
+         )
+     result = processor.batch_decode(generated_ids, skip_special_tokens=True)
+     return result[0][len(text):].lstrip("\n")
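
+ # Usage sketch, reusing the example image and prompt wired up below:
+ #   img = PIL.Image.open("./diagnosis-1.jpg")
+ #   print(infer(img, "Transcribe the Arabic text.", max_new_tokens=20))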


+ ##### Parse segmentation output tokens into masks
+ ##### Also returns bounding boxes with their labels
+
+ def parse_segmentation(input_image, input_text):
+     out = infer(input_image, input_text, max_new_tokens=100)
+     objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
+     labels = set(obj.get('name') for obj in objs if obj.get('name'))
+     color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+     highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+     annotated_img = (
+         input_image,
+         [
+             (
+                 obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+                 obj['name'] or '',
+             )
+             for obj in objs
+             if 'mask' in obj or 'xyxy' in obj
+         ],
+     )
+     has_annotations = bool(annotated_img[1])
+     return annotated_img
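
+ # The returned tuple is the value format gr.AnnotatedImage expects: the base
+ # image plus a list of (mask-or-box, label) annotation pairs.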
+
+
+ ######## Demo
+
+ INTRO_TEXT = """## Curacel Handwritten Arabic demo\n\n
+ Finetuned from: google/paligemma-3b-pt-448
+
+
+ Translation model demo at: https://prod.arabic-gpt.ai/
+
+ Prompts:
+ Translate the Arabic to English: {model output}
+
+ The following is a diagnosis in Arabic from a medical billing form that we need to translate to English. The transcription is not necessarily accurate, so one or more characters or words may be wrong. Given what is written, what is the most likely diagnosis? Think step by step, and consider similar words or misspellings in Arabic. Give multiple Arabic diagnoses along with an English translation for each, then select the diagnosis that makes the most sense given what was transcribed and print its English translation as your most likely final translation. Transcribed text: {model output}
  """
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(INTRO_TEXT)
+     with gr.Tab("Text Generation"):
+         with gr.Column():
+             image = gr.Image(type="pil")
+             text_input = gr.Text(label="Input Text")
+
+         text_output = gr.Text(label="Text Output")
+         chat_btn = gr.Button()
+         tokens = gr.Slider(
+             label="Max New Tokens",
+             info="Set to larger for longer generation.",
+             minimum=10,
+             maximum=100,
+             value=20,
+             step=10,
+         )
+
+         chat_inputs = [
+             image,
+             text_input,
+             tokens
+         ]
+         chat_outputs = [
+             text_output
+         ]
+         chat_btn.click(
+             fn=infer,
+             inputs=chat_inputs,
+             outputs=chat_outputs,
+         )
+
+         examples = [["./diagnosis-1.jpg", "Transcribe the Arabic text."],
+                     ["./examples/sign.jpg", "Transcribe the Arabic text."]]
+         gr.Markdown("")
+
+         gr.Examples(
+             examples=examples,
+             inputs=chat_inputs,
+         )
+     '''
+     with gr.Tab("Segment/Detect"):
+         image = gr.Image(type="pil")
+         seg_input = gr.Text(label="Entities to Segment/Detect")
+         seg_btn = gr.Button("Submit")
+         annotated_image = gr.AnnotatedImage(label="Output")
+
+         examples = [["./diagnosis-1.jpg", "Transcribe the Arabic text."],
+                     ["./examples/sign.jpg", "Transcribe the Arabic text."]]
+         gr.Markdown(
+             "")
+         gr.Examples(
+             examples=examples,
+             inputs=[image, seg_input],
+         )
+
+         seg_inputs = [
+             image,
+             seg_input
+         ]
+         seg_outputs = [
+             annotated_image
+         ]
+         seg_btn.click(
+             fn=parse_segmentation,
+             inputs=seg_inputs,
+             outputs=seg_outputs,
+         )
+     '''
+
+ ### Postprocessing Utils for Segmentation Tokens
+ ### Segmentation tokens are passed to another VAE which decodes them to a mask
+
+ _MODEL_PATH = 'vae-oid.npz'
+
+ _SEGMENT_DETECT_RE = re.compile(
+     r'(.*?)' +
+     r'<loc(\d{4})>' * 4 + r'\s*' +
+     '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
+     r'\s*([^;<>]+)? ?(?:; )?',
  )
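
+ # Shape of a match (hypothetical token values): four <loc####> tokens encode a
+ # box as y1, x1, y2, x2 on a 1024-bin grid, optionally followed by sixteen
+ # <seg###> codebook indices and a free-text label, e.g.
+ #   "<loc0100><loc0200><loc0300><loc0400> some label"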


+ def _get_params(checkpoint):
+     """Converts PyTorch checkpoint to Flax params."""
+
+     def transp(kernel):
+         return np.transpose(kernel, (2, 3, 1, 0))
+
+     def conv(name):
+         return {
+             'bias': checkpoint[name + '.bias'],
+             'kernel': transp(checkpoint[name + '.weight']),
+         }
+
+     def resblock(name):
+         return {
+             'Conv_0': conv(name + '.0'),
+             'Conv_1': conv(name + '.2'),
+             'Conv_2': conv(name + '.4'),
+         }
+
+     return {
+         '_embeddings': checkpoint['_vq_vae._embedding'],
+         'Conv_0': conv('decoder.0'),
+         'ResBlock_0': resblock('decoder.2.net'),
+         'ResBlock_1': resblock('decoder.3.net'),
+         'ConvTranspose_0': conv('decoder.4'),
+         'ConvTranspose_1': conv('decoder.6'),
+         'ConvTranspose_2': conv('decoder.8'),
+         'ConvTranspose_3': conv('decoder.10'),
+         'Conv_1': conv('decoder.12'),
+     }
+
+
+ def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+     batch_size, num_tokens = codebook_indices.shape
+     assert num_tokens == 16, codebook_indices.shape
+     unused_num_embeddings, embedding_dim = embeddings.shape
+
+     encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+     encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+     return encodings
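
+ # i.e. indices shaped [B, 16] are looked up in the codebook and reshaped into
+ # a [B, 4, 4, embedding_dim] grid for the decoder below.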
+
+
+ @functools.cache
+ def _get_reconstruct_masks():
+     """Reconstructs masks from codebook indices.
+     Returns:
+       A function that expects indices shaped `[B, 16]` of dtype int32, each
+       ranging from 0 to 127 (inclusive), and that returns decoded masks sized
+       `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
+     """
+
+     class ResBlock(nn.Module):
+         features: int
+
+         @nn.compact
+         def __call__(self, x):
+             original_x = x
+             x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+             x = nn.relu(x)
+             x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+             x = nn.relu(x)
+             x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
+             return x + original_x
+
+     class Decoder(nn.Module):
+         """Upscales quantized vectors to mask."""
+
+         @nn.compact
+         def __call__(self, x):
+             num_res_blocks = 2
+             dim = 128
+             num_upsample_layers = 4
+
+             x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
+             x = nn.relu(x)
+
+             for _ in range(num_res_blocks):
+                 x = ResBlock(features=dim)(x)
+
+             for _ in range(num_upsample_layers):
+                 x = nn.ConvTranspose(
+                     features=dim,
+                     kernel_size=(4, 4),
+                     strides=(2, 2),
+                     padding=2,
+                     transpose_kernel=True,
+                 )(x)
+                 x = nn.relu(x)
+                 dim //= 2
+
+             x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
+
+             return x
+
+     def reconstruct_masks(codebook_indices):
+         quantized = _quantized_values_from_codebook_indices(
+             codebook_indices, params['_embeddings']
+         )
+         return Decoder().apply({'params': params}, quantized)
+
+     with open(_MODEL_PATH, 'rb') as f:
+         params = _get_params(dict(np.load(f)))
+
+     return jax.jit(reconstruct_masks, backend='cpu')
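
+ # Usage sketch (hypothetical indices): sixteen codebook ids decode to one mask.
+ #   idxs = np.zeros((1, 16), dtype=np.int32)  # values must be in [0, 127]
+ #   masks = _get_reconstruct_masks()(idxs)    # shape [1, 64, 64, 1], in [-1, 1]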
+
+
+ def extract_objs(text, width, height, unique_labels=False):
+     """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+     objs = []
+     seen = set()
+     while text:
+         m = _SEGMENT_DETECT_RE.match(text)
+         if not m:
+             break
+         print("m", m)
+         gs = list(m.groups())
+         before = gs.pop(0)
+         name = gs.pop()
+         y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+         y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
+         seg_indices = gs[4:20]
+         if seg_indices[0] is None:
+             mask = None
+         else:
+             seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
+             m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
+             m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
+             m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
+             mask = np.zeros([height, width])
+             if y2 > y1 and x2 > x1:
+                 mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
+
+         content = m.group()
+         if before:
+             objs.append(dict(content=before))
+             content = content[len(before):]
+         while unique_labels and name in seen:
+             name = (name or '') + "'"
+         seen.add(name)
+         objs.append(dict(
+             content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+         text = text[len(before) + len(content):]
+
+     if text:
+         objs.append(dict(content=text))
+
+     return objs
+
+
+ #########
+
  if __name__ == "__main__":
+     demo.queue(max_size=10).launch(debug=True)
diagnosis-1.png ADDED
sign.png ADDED