Brice Vandeputte committed
Commit 6bde7ff
1 Parent(s): a0562b3

Pick bioclip src and adapt demo

.gitignore CHANGED
@@ -3,3 +3,4 @@ flagged/
 node_modules/
 venv/
 myenv/
+__pycache__/
PredictService.py ADDED
@@ -0,0 +1,40 @@
+import tempfile
+import requests
+from src.bioclip.predict import TreeOfLifeClassifier, Rank
+import logging
+
+class PredictService:
+
+    def __init__(self):
+        self.classifier = TreeOfLifeClassifier()
+        log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
+        logging.basicConfig(level=logging.INFO, format=log_format)
+        self.logger = logging.getLogger()
+
+    def download_image(self, url):
+        self.logger.info(f'download_image({url})')
+        response = requests.get(url)
+
+        # Check whether the request succeeded
+        if response.status_code == 200:
+            # Create a temporary file
+            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
+
+            # Write the image content to the temporary file
+            temp_file.write(response.content)
+            temp_file.close()
+
+            # Return the temporary file path
+            return temp_file.name
+        else:
+            raise Exception("Error while downloading image. Status: {}".format(response.status_code))
+
+    def predict(self, image_url=None):
+        if image_url is None:
+            raise Exception("expect image url")
+        image_path = self.download_image(image_url)
+        predictions = self.classifier.predict(image_path, Rank.SPECIES)
+        for prediction in predictions:
+            if 'file_name' in prediction:
+                del prediction['file_name']
+        return predictions
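For reference, a minimal usage sketch of the new service, assuming the dependencies added to requirements.txt in this commit are installed; the image URL is illustrative, not from the repo:

from PredictService import PredictService

svc = PredictService()
# download_image() fetches the URL into a temporary .jpg file, then predict()
# runs species-rank classification and strips the internal 'file_name' key
predictions = svc.predict("https://example.org/some-bird.jpg")  # hypothetical URL
for p in predictions:
    print(p["species"], p["common_name"], p["score"])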
app.py CHANGED
@@ -1,35 +1,42 @@
 # https://www.gradio.app/guides/sharing-your-app#mounting-within-another-fast-api-app
-import logging
-import json
 import gradio as gr
+import json
+import logging
+from PredictService import PredictService
 
 log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=log_format)
 logger = logging.getLogger()
 
+svc = PredictService()
+
 
 def api_classification(url):
-    logger.info(f'api_classification({url})')
-    data = {"status":"WorkInProgress"}
-    return json.dumps(data)
+    url_to_use = url
+    if url_to_use == "exemple":
+        url_to_use = "https://images.pexels.com/photos/326900/pexels-photo-326900.jpeg?cs=srgb&dl=pexels-pixabay-326900.jpg&fm=jpg"
+    predictions = svc.predict(url_to_use)
+    return json.dumps(predictions)
 
 
 with gr.Blocks() as app:
     with gr.Tab("BioCLIP API"):
         with gr.Row():
             with gr.Column():
+                # https://www.gradio.app/guides/key-component-concepts
+                gr.Textbox(value="This is a BioCLIP based prediction. You must input a public url of an image and you will get TreeOfLife predictions as result", interactive=False)
                 api_input = gr.Textbox(
                     placeholder="Image url here",
                     lines=1,
                     label="Image url",
                     show_label=True,
                     info="Add image url here.",
+                    value="https://natureconservancy-h.assetsadobe.com/is/image/content/dam/tnc/nature/en/photos/d/o/Downy-woodpecker-Matt-Williams.jpg?crop=0%2C39%2C3097%2C2322&wid=820&hei=615&scl=3.776829268292683"
                 )
                 api_classification_btn = gr.Button("API", variant="primary")
             with gr.Column():
                 api_classification_output = gr.JSON()  # https://www.gradio.app/docs/gradio/json
 
-
     api_classification_btn.click(
         fn=api_classification,
         inputs=[api_input],
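The hunk is cut off at the diff context boundary after inputs=[api_input],; presumably the call continues by binding the JSON output component, along these lines (a sketch of the expected tail, not verbatim file content):

    api_classification_btn.click(
        fn=api_classification,
        inputs=[api_input],
        outputs=[api_classification_output],  # assumed wiring to the gr.JSON above
    )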
requirements.txt CHANGED
@@ -1,2 +1,6 @@
 huggingface_hub==0.22.2
-gradio
+gradio
+# bioclip deps
+open_clip_torch
+torchvision
+torch
src/bioclip/__init__.py ADDED
@@ -0,0 +1,6 @@
+# SPDX-FileCopyrightText: 2024-present John Bradley <[email protected]>
+#
+# SPDX-License-Identifier: MIT
+from .predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
+
+__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier"]
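This re-export lets callers import from the package root instead of the predict module, e.g. (assuming src/ is on the import path, as in PredictService.py above):

from src.bioclip import TreeOfLifeClassifier, Rank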
src/bioclip/predict.py ADDED
@@ -0,0 +1,305 @@
+import json
+import torch
+from torchvision import transforms
+from open_clip import create_model, get_tokenizer
+import torch.nn.functional as F
+import numpy as np
+import collections
+import heapq
+import PIL.Image
+from huggingface_hub import hf_hub_download
+from typing import Union, List
+from enum import Enum
+
+
+HF_DATAFILE_REPO = "imageomics/bioclip-demo"
+HF_DATAFILE_REPO_TYPE = "space"
+PRED_FILENAME_KEY = "file_name"
+PRED_CLASSICATION_KEY = "classification"
+PRED_SCORE_KEY = "score"
+
+OPENA_AI_IMAGENET_TEMPLATE = [
+    lambda c: f"a bad photo of a {c}.",
+    lambda c: f"a photo of many {c}.",
+    lambda c: f"a sculpture of a {c}.",
+    lambda c: f"a photo of the hard to see {c}.",
+    lambda c: f"a low resolution photo of the {c}.",
+    lambda c: f"a rendering of a {c}.",
+    lambda c: f"graffiti of a {c}.",
+    lambda c: f"a bad photo of the {c}.",
+    lambda c: f"a cropped photo of the {c}.",
+    lambda c: f"a tattoo of a {c}.",
+    lambda c: f"the embroidered {c}.",
+    lambda c: f"a photo of a hard to see {c}.",
+    lambda c: f"a bright photo of a {c}.",
+    lambda c: f"a photo of a clean {c}.",
+    lambda c: f"a photo of a dirty {c}.",
+    lambda c: f"a dark photo of the {c}.",
+    lambda c: f"a drawing of a {c}.",
+    lambda c: f"a photo of my {c}.",
+    lambda c: f"the plastic {c}.",
+    lambda c: f"a photo of the cool {c}.",
+    lambda c: f"a close-up photo of a {c}.",
+    lambda c: f"a black and white photo of the {c}.",
+    lambda c: f"a painting of the {c}.",
+    lambda c: f"a painting of a {c}.",
+    lambda c: f"a pixelated photo of the {c}.",
+    lambda c: f"a sculpture of the {c}.",
+    lambda c: f"a bright photo of the {c}.",
+    lambda c: f"a cropped photo of a {c}.",
+    lambda c: f"a plastic {c}.",
+    lambda c: f"a photo of the dirty {c}.",
+    lambda c: f"a jpeg corrupted photo of a {c}.",
+    lambda c: f"a blurry photo of the {c}.",
+    lambda c: f"a photo of the {c}.",
+    lambda c: f"a good photo of the {c}.",
+    lambda c: f"a rendering of the {c}.",
+    lambda c: f"a {c} in a video game.",
+    lambda c: f"a photo of one {c}.",
+    lambda c: f"a doodle of a {c}.",
+    lambda c: f"a close-up photo of the {c}.",
+    lambda c: f"a photo of a {c}.",
+    lambda c: f"the origami {c}.",
+    lambda c: f"the {c} in a video game.",
+    lambda c: f"a sketch of a {c}.",
+    lambda c: f"a doodle of the {c}.",
+    lambda c: f"a origami {c}.",
+    lambda c: f"a low resolution photo of a {c}.",
+    lambda c: f"the toy {c}.",
+    lambda c: f"a rendition of the {c}.",
+    lambda c: f"a photo of the clean {c}.",
+    lambda c: f"a photo of a large {c}.",
+    lambda c: f"a rendition of a {c}.",
+    lambda c: f"a photo of a nice {c}.",
+    lambda c: f"a photo of a weird {c}.",
+    lambda c: f"a blurry photo of a {c}.",
+    lambda c: f"a cartoon {c}.",
+    lambda c: f"art of a {c}.",
+    lambda c: f"a sketch of the {c}.",
+    lambda c: f"a embroidered {c}.",
+    lambda c: f"a pixelated photo of a {c}.",
+    lambda c: f"itap of the {c}.",
+    lambda c: f"a jpeg corrupted photo of the {c}.",
+    lambda c: f"a good photo of a {c}.",
+    lambda c: f"a plushie {c}.",
+    lambda c: f"a photo of the nice {c}.",
+    lambda c: f"a photo of the small {c}.",
+    lambda c: f"a photo of the weird {c}.",
+    lambda c: f"the cartoon {c}.",
+    lambda c: f"art of the {c}.",
+    lambda c: f"a drawing of the {c}.",
+    lambda c: f"a photo of the large {c}.",
+    lambda c: f"a black and white photo of a {c}.",
+    lambda c: f"the plushie {c}.",
+    lambda c: f"a dark photo of a {c}.",
+    lambda c: f"itap of a {c}.",
+    lambda c: f"graffiti of the {c}.",
+    lambda c: f"a toy {c}.",
+    lambda c: f"itap of my {c}.",
+    lambda c: f"a photo of a cool {c}.",
+    lambda c: f"a photo of a small {c}.",
+    lambda c: f"a tattoo of the {c}.",
+]
+
+
+def get_cached_datafile(filename: str):
+    return hf_hub_download(repo_id=HF_DATAFILE_REPO, filename=filename, repo_type=HF_DATAFILE_REPO_TYPE)
+
+
+def get_txt_emb():
+    txt_emb_npy = get_cached_datafile("txt_emb_species.npy")
+    return torch.from_numpy(np.load(txt_emb_npy))
+
+
+def get_txt_names():
+    txt_names_json = get_cached_datafile("txt_emb_species.json")
+    with open(txt_names_json) as fd:
+        txt_names = json.load(fd)
+    return txt_names
+
+
+preprocess_img = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Resize((224, 224), antialias=True),
+        transforms.Normalize(
+            mean=(0.48145466, 0.4578275, 0.40821073),
+            std=(0.26862954, 0.26130258, 0.27577711),
+        ),
+    ]
+)
+
+class Rank(Enum):
+    KINGDOM = 0
+    PHYLUM = 1
+    CLASS = 2
+    ORDER = 3
+    FAMILY = 4
+    GENUS = 5
+    SPECIES = 6
+
+    def get_label(self):
+        return self.name.lower()
+
+
+# The datafile of names ('txt_emb_species.json') contains the species epithet.
+# To create a label for a species we concatenate the genus and species epithet.
+SPECIES_LABEL = Rank.SPECIES.get_label()
+SPECIES_EPITHET_LABEL = "species_epithet"
+COMMON_NAME_LABEL = "common_name"
+
+
+def create_bioclip_model(model_str="hf-hub:imageomics/bioclip", device="cuda"):
+    model = create_model(model_str, output_dict=True, require_pretrained=True)
+    model = model.to(device)
+    return torch.compile(model)
+
+
+def create_bioclip_tokenizer(tokenizer_str="ViT-B-16"):
+    return get_tokenizer(tokenizer_str)
+
+
+class CustomLabelsClassifier(object):
+    def __init__(self, device: Union[str, torch.device] = 'cpu'):
+        self.device = device
+        self.model = create_bioclip_model(device=device)
+        self.tokenizer = create_bioclip_tokenizer()
+
+    def get_txt_features(self, classnames):
+        all_features = []
+        for classname in classnames:
+            txts = [template(classname) for template in OPENA_AI_IMAGENET_TEMPLATE]
+            txts = self.tokenizer(txts).to(self.device)
+            txt_features = self.model.encode_text(txts)
+            txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
+            txt_features /= txt_features.norm()
+            all_features.append(txt_features)
+        all_features = torch.stack(all_features, dim=1)
+        return all_features
+
+    @torch.no_grad()
+    def predict(self, image_path: str, cls_ary: List[str]) -> List[dict[str, float]]:
+        img = PIL.Image.open(image_path)
+        classes = [cls.strip() for cls in cls_ary]
+        txt_features = self.get_txt_features(classes)
+
+        img = preprocess_img(img).to(self.device)
+        img_features = self.model.encode_image(img.unsqueeze(0))
+        img_features = F.normalize(img_features, dim=-1)
+
+        logits = (self.model.logit_scale.exp() * img_features @ txt_features).squeeze()
+        probs = F.softmax(logits, dim=0).to("cpu").tolist()
+        pred_list = []
+        for cls, prob in zip(classes, probs):
+            pred_list.append({
+                PRED_FILENAME_KEY: image_path,
+                PRED_CLASSICATION_KEY: cls,
+                PRED_SCORE_KEY: prob
+            })
+        return pred_list
+
+
+def predict_classifications_from_list(img: Union[PIL.Image.Image, str], cls_ary: List[str], device: Union[str, torch.device] = 'cpu') -> List[dict[str, float]]:
+    classifier = CustomLabelsClassifier(device=device)
+    return classifier.predict(img, cls_ary)
+
+
+def get_tol_classification_labels(rank: Rank) -> List[str]:
+    names = []
+    for i in range(rank.value + 1):
+        i_rank = Rank(i)
+        if i_rank == Rank.SPECIES:
+            names.append(SPECIES_EPITHET_LABEL)
+        rank_name = i_rank.name.lower()
+        names.append(rank_name)
+    if rank == Rank.SPECIES:
+        names.append(COMMON_NAME_LABEL)
+    return names
+
+
+def create_classification_dict(names: List[List[str]], rank: Rank) -> dict[str, str]:
+    scientific_names = names[0]
+    common_name = names[1]
+    classification_dict = {}
+    for idx, label in enumerate(get_tol_classification_labels(rank=rank)):
+        if label == SPECIES_LABEL:
+            value = scientific_names[-2] + " " + scientific_names[-1]
+        elif label == COMMON_NAME_LABEL:
+            value = common_name
+        else:
+            value = scientific_names[idx]
+        classification_dict[label] = value
+    return classification_dict
+
+
+def join_names(classification_dict: dict[str, str]) -> str:
+    return " ".join(classification_dict.values())
+
+
+class TreeOfLifeClassifier(object):
+    def __init__(self, device: Union[str, torch.device] = 'cpu'):
+        self.device = device
+        self.model = create_bioclip_model(device=device)
+        self.txt_emb = get_txt_emb().to(device)
+        self.txt_names = get_txt_names()
+
+    def encode_image(self, img: PIL.Image.Image) -> torch.Tensor:
+        img = preprocess_img(img).to(self.device)
+        img_features = self.model.encode_image(img.unsqueeze(0))
+        return img_features
+
+    def predict_species(self, img: PIL.Image.Image) -> torch.Tensor:
+        img_features = self.encode_image(img)
+        img_features = F.normalize(img_features, dim=-1)
+        logits = (self.model.logit_scale.exp() * img_features @ self.txt_emb).squeeze()
+        probs = F.softmax(logits, dim=0)
+        return probs
+
+    def format_species_probs(self, image_path: str, probs: torch.Tensor, k: int = 5) -> List[dict[str, float]]:
+        topk = probs.topk(k)
+        result = []
+        for i, prob in zip(topk.indices, topk.values):
+            item = { PRED_FILENAME_KEY: image_path }
+            item.update(create_classification_dict(self.txt_names[i], Rank.SPECIES))
+            item[PRED_SCORE_KEY] = prob.item()
+            result.append(item)
+        return result
+
+    def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank, min_prob: float = 1e-9, k: int = 5) -> List[dict[str, float]]:
+        output = collections.defaultdict(float)
+        class_dict_lookup = {}
+        name_to_class_dict = {}
+        for i in torch.nonzero(probs > min_prob).squeeze():
+            classification_dict = create_classification_dict(self.txt_names[i], rank)
+            name = join_names(classification_dict)
+            class_dict_lookup[name] = classification_dict
+            output[name] += probs[i]
+            name_to_class_dict[name] = classification_dict
+        topk_names = heapq.nlargest(k, output, key=output.get)
+        prediction_ary = []
+        for name in topk_names:
+            item = { PRED_FILENAME_KEY: image_path }
+            item.update(name_to_class_dict[name])
+            #item.update(class_dict_lookup)
+            item[PRED_SCORE_KEY] = output[name].item()
+            prediction_ary.append(item)
+        return prediction_ary
+
+    @torch.no_grad()
+    def predict(self, image_path: str, rank: Rank, min_prob: float = 1e-9, k: int = 5) -> List[dict[str, float]]:
+        img = PIL.Image.open(image_path)
+        probs = self.predict_species(img)
+        if rank == Rank.SPECIES:
+            return self.format_species_probs(image_path, probs, k)
+        return self.format_grouped_probs(image_path, probs, rank, min_prob, k)
+
+
+def predict_classification(img: str, rank: Rank, device: Union[str, torch.device] = 'cpu',
+                           min_prob: float = 1e-9, k: int = 5) -> List[dict[str, float]]:
+    """
+    Predicts from the entire tree of life.
+    If targeting a higher rank than species, then this function predicts among all
+    species, then sums up species-level probabilities for the given rank.
+    """
+    classifier = TreeOfLifeClassifier(device=device)
+    return classifier.predict(img, rank, min_prob, k)
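A short sketch of how the two module-level entry points are called; the image path and label list are illustrative, not from the repo:

from src.bioclip.predict import (
    Rank,
    predict_classification,
    predict_classifications_from_list,
)

# Tree-of-life prediction: species-level probabilities summed up to the family rank
preds = predict_classification("bird.jpg", Rank.FAMILY, device="cpu")

# Zero-shot classification over custom labels via CustomLabelsClassifier
custom = predict_classifications_from_list("bird.jpg", ["bird", "insect"], device="cpu")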