thompsonmj commited on
Commit
6277b48
·
1 Parent(s): 30b4d4e

Retrieve example TOL-10M image and representative EOL page for OE prediction

Browse files

Co-authored by: Elizabeth Campolongo <[email protected]>

Files changed (3) hide show
  1. app.py +71 -12
  2. components/query.py +115 -0
  3. requirements.txt +3 -0
app.py CHANGED
@@ -6,12 +6,14 @@ import logging
6
 
7
  import gradio as gr
8
  import numpy as np
 
9
  import torch
10
  import torch.nn.functional as F
11
  from open_clip import create_model, get_tokenizer
12
  from torchvision import transforms
13
 
14
  from templates import openai_imagenet_template
 
15
 
16
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
17
  logging.basicConfig(level=logging.INFO, format=log_format)
@@ -19,6 +21,12 @@ logger = logging.getLogger()
19
 
20
  hf_token = os.getenv("HF_TOKEN")
21
 
 
 
 
 
 
 
22
  model_str = "hf-hub:imageomics/bioclip"
23
  tokenizer_str = "ViT-B-16"
24
 
@@ -123,12 +131,14 @@ def format_name(taxon, common):
123
 
124
 
125
  @torch.no_grad()
126
- def open_domain_classification(img, rank: int) -> dict[str, float]:
127
  """
128
  Predicts from the entire tree of life.
129
  If targeting a higher rank than species, then this function predicts among all
130
  species, then sums up species-level probabilities for the given rank.
131
  """
 
 
132
  img = preprocess_img(img).to(device)
133
  img_features = model.encode_image(img.unsqueeze(0))
134
  img_features = F.normalize(img_features, dim=-1)
@@ -136,21 +146,36 @@ def open_domain_classification(img, rank: int) -> dict[str, float]:
136
  logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
137
  probs = F.softmax(logits, dim=0)
138
 
139
- # If predicting species, no need to sum probabilities.
140
  if rank + 1 == len(ranks):
141
  topk = probs.topk(k)
142
- return {
143
  format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
144
  }
 
 
 
 
 
 
 
145
 
146
- # Sum up by the rank
147
  output = collections.defaultdict(float)
148
  for i in torch.nonzero(probs > min_prob).squeeze():
149
  output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
150
 
151
  topk_names = heapq.nlargest(k, output, key=output.get)
 
 
 
152
 
153
- return {name: output[name] for name in topk_names}
 
 
 
 
 
 
 
154
 
155
 
156
  def change_output(choice):
@@ -179,9 +204,19 @@ if __name__ == "__main__":
179
  status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
180
 
181
  with gr.Blocks() as app:
182
- img_input = gr.Image()
183
-
184
  with gr.Tab("Open-Ended"):
 
 
 
 
 
 
 
 
 
 
 
185
  with gr.Row():
186
  with gr.Column():
187
  rank_dropdown = gr.Dropdown(
@@ -201,12 +236,17 @@ if __name__ == "__main__":
201
  )
202
  # open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
203
 
 
 
 
 
 
204
  with gr.Row():
205
  gr.Examples(
206
  examples=open_domain_examples,
207
  inputs=[img_input, rank_dropdown],
208
  cache_examples=True,
209
- fn=open_domain_classification,
210
  outputs=[open_domain_output],
211
  )
212
  '''
@@ -225,6 +265,9 @@ if __name__ == "__main__":
225
  )
226
  '''
227
  with gr.Tab("Zero-Shot"):
 
 
 
228
  with gr.Row():
229
  with gr.Column():
230
  classes_txt = gr.Textbox(
@@ -245,7 +288,7 @@ if __name__ == "__main__":
245
  with gr.Row():
246
  gr.Examples(
247
  examples=zero_shot_examples,
248
- inputs=[img_input, classes_txt],
249
  cache_examples=True,
250
  fn=zero_shot_classification,
251
  outputs=[zero_shot_output],
@@ -268,17 +311,33 @@ if __name__ == "__main__":
268
  fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
269
  )
270
 
 
 
 
 
 
 
271
  open_domain_btn.click(
272
- fn=open_domain_classification,
273
  inputs=[img_input, rank_dropdown],
274
- outputs=[open_domain_output],
275
  )
276
 
277
  zero_shot_btn.click(
278
  fn=zero_shot_classification,
279
- inputs=[img_input, classes_txt],
280
  outputs=zero_shot_output,
281
  )
 
 
 
 
 
 
 
 
 
 
282
 
283
  app.queue(max_size=20)
284
  app.launch()
 
6
 
7
  import gradio as gr
8
  import numpy as np
9
+ import polars as pl
10
  import torch
11
  import torch.nn.functional as F
12
  from open_clip import create_model, get_tokenizer
13
  from torchvision import transforms
14
 
15
  from templates import openai_imagenet_template
16
+ from components.query import get_sample
17
 
18
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
19
  logging.basicConfig(level=logging.INFO, format=log_format)
 
21
 
22
  hf_token = os.getenv("HF_TOKEN")
23
 
24
+ # For sample images
25
+ METADATA_PATH = "components/metadata.csv"
26
+ # Read page ID as int and filter out smaller ablation duplicated training split
27
+ metadata_df = pl.read_csv(METADATA_PATH, low_memory = False)
28
+ metadata_df = metadata_df.with_columns(pl.col("eol_page_id").cast(pl.Int64))
29
+
30
  model_str = "hf-hub:imageomics/bioclip"
31
  tokenizer_str = "ViT-B-16"
32
 
 
131
 
132
 
133
  @torch.no_grad()
134
+ def open_domain_classification(img, rank: int, return_all=False):
135
  """
136
  Predicts from the entire tree of life.
137
  If targeting a higher rank than species, then this function predicts among all
138
  species, then sums up species-level probabilities for the given rank.
139
  """
140
+
141
+ logger.info(f"Starting open domain classification for rank: {rank}")
142
  img = preprocess_img(img).to(device)
143
  img_features = model.encode_image(img.unsqueeze(0))
144
  img_features = F.normalize(img_features, dim=-1)
 
146
  logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
147
  probs = F.softmax(logits, dim=0)
148
 
 
149
  if rank + 1 == len(ranks):
150
  topk = probs.topk(k)
151
+ prediction_dict = {
152
  format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
153
  }
154
+ logger.info(f"Top K predictions: {prediction_dict}")
155
+ top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
156
+ logger.info(f"Top prediction name: {top_prediction_name}")
157
+ sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
158
+ if return_all:
159
+ return prediction_dict, sample_img, taxon_url
160
+ return prediction_dict
161
 
 
162
  output = collections.defaultdict(float)
163
  for i in torch.nonzero(probs > min_prob).squeeze():
164
  output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
165
 
166
  topk_names = heapq.nlargest(k, output, key=output.get)
167
+ prediction_dict = {name: output[name] for name in topk_names}
168
+ logger.info(f"Top K names for output: {topk_names}")
169
+ logger.info(f"Prediction dictionary: {prediction_dict}")
170
 
171
+ top_prediction_name = topk_names[0]
172
+ logger.info(f"Top prediction name: {top_prediction_name}")
173
+ sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
174
+ logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
175
+
176
+ if return_all:
177
+ return prediction_dict, sample_img, taxon_url
178
+ return prediction_dict
179
 
180
 
181
  def change_output(choice):
 
204
  status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
205
 
206
  with gr.Blocks() as app:
207
+
 
208
  with gr.Tab("Open-Ended"):
209
+ # with gr.Row(variant = "panel", elem_id = "images_panel"):
210
+ with gr.Row(variant = "panel", elem_id = "images_panel"):
211
+ with gr.Column():
212
+ img_input = gr.Image(height = 400, sources=["upload"])
213
+
214
+ with gr.Column():
215
+ # display sample image of top predicted taxon
216
+ sample_img = gr.Image(label = "Sample Image of Predicted Taxon",
217
+ height = 400,
218
+ show_download_button = False)
219
+
220
  with gr.Row():
221
  with gr.Column():
222
  rank_dropdown = gr.Dropdown(
 
236
  )
237
  # open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
238
 
239
+ with gr.Row():
240
+ taxon_url = gr.TextArea(label = "More Information",
241
+ elem_id = "textbox",
242
+ show_copy_button = True)
243
+
244
  with gr.Row():
245
  gr.Examples(
246
  examples=open_domain_examples,
247
  inputs=[img_input, rank_dropdown],
248
  cache_examples=True,
249
+ fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
250
  outputs=[open_domain_output],
251
  )
252
  '''
 
265
  )
266
  '''
267
  with gr.Tab("Zero-Shot"):
268
+ with gr.Row():
269
+ img_input_zs = gr.Image(height = 400, sources=["upload"])
270
+
271
  with gr.Row():
272
  with gr.Column():
273
  classes_txt = gr.Textbox(
 
288
  with gr.Row():
289
  gr.Examples(
290
  examples=zero_shot_examples,
291
+ inputs=[img_input_zs, classes_txt],
292
  cache_examples=True,
293
  fn=zero_shot_classification,
294
  outputs=[zero_shot_output],
 
311
  fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
312
  )
313
 
314
+ # open_domain_btn.click(
315
+ # fn=open_domain_classification,
316
+ # inputs=[img_input, rank_dropdown],
317
+ # outputs=[open_domain_output, sample_img, taxon_url],
318
+ # )
319
+
320
  open_domain_btn.click(
321
+ fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
322
  inputs=[img_input, rank_dropdown],
323
+ outputs=[open_domain_output, sample_img, taxon_url],
324
  )
325
 
326
  zero_shot_btn.click(
327
  fn=zero_shot_classification,
328
+ inputs=[img_input_zs, classes_txt],
329
  outputs=zero_shot_output,
330
  )
331
+
332
+ # Footer to point out to model and data from app page.
333
+ gr.Markdown(
334
+ """
335
+ For more information on the [BioCLIP Model](https://huggingface.co/imageomics/bioclip) creation, see our [BioCLIP Project GitHub](https://github.com/Imageomics/bioclip), and
336
+ for easier integration of BioCLIP, checkout [pybioclip](https://github.com/Imageomics/pybioclip).
337
+
338
+ To learn more about the data, check out our [TreeOfLife-10M Dataset](https://huggingface.co/datasets/imageomics/TreeOfLife-10M).
339
+ """
340
+ )
341
 
342
  app.queue(max_size=20)
343
  app.launch()
components/query.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import boto3
3
+ import requests
4
+ import numpy as np
5
+ import polars as pl
6
+ from PIL import Image
7
+ from botocore.config import Config
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # S3 for sample images
13
+ my_config = Config(
14
+ region_name='us-east-1'
15
+ )
16
+ s3_client = boto3.client('s3', config=my_config)
17
+
18
+ # Set basepath for EOL pages for info
19
+ EOL_URL = "https://eol.org/pages/"
20
+ RANKS = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]
21
+
22
+ def get_sample(df, pred_taxon, rank):
23
+ '''
24
+ Function to retrieve a sample image of the predicted taxon and EOL page link for more info.
25
+
26
+ Parameters:
27
+ -----------
28
+ df : DataFrame
29
+ DataFrame with all sample images listed and their filepaths (in "file_path" column).
30
+ pred_taxon : str
31
+ Predicted taxon of the uploaded image.
32
+ rank : int
33
+ Index of rank in RANKS chosen for prediction.
34
+
35
+ Returns:
36
+ --------
37
+ img : PIL.Image
38
+ Sample image of predicted taxon for display.
39
+ eol_page : str
40
+ URL to EOL page for the taxon (may be a lower rank, e.g., species sample).
41
+ '''
42
+ logger.info(f"Getting sample for taxon: {pred_taxon} at rank: {rank}")
43
+ try:
44
+ filepath, eol_page_id, full_name, is_exact = get_sample_data(df, pred_taxon, rank)
45
+ except Exception as e:
46
+ logger.error(f"Error retrieving sample data: {e}")
47
+ return None, f"We encountered the following error trying to retrieve a sample image: {e}."
48
+ if filepath is None:
49
+ logger.warning(f"No sample image found for taxon: {pred_taxon}")
50
+ return None, f"Sorry, our EOL images do not include {pred_taxon}."
51
+
52
+ # Get sample image of selected individual
53
+ try:
54
+ img_src = s3_client.generate_presigned_url('get_object',
55
+ Params={'Bucket': 'treeoflife-10m-sample-images',
56
+ 'Key': filepath}
57
+ )
58
+ img_resp = requests.get(img_src)
59
+ img = Image.open(io.BytesIO(img_resp.content))
60
+ if is_exact:
61
+ eol_page = f"Check out the EOL entry for {pred_taxon} to learn more: {EOL_URL}{eol_page_id}."
62
+ else:
63
+ eol_page = f"Check out an example EOL entry within {pred_taxon} to learn more: {full_name} {EOL_URL}{eol_page_id}."
64
+ logger.info(f"Successfully retrieved sample image and EOL page for {pred_taxon}")
65
+ return img, eol_page
66
+ except Exception as e:
67
+ logger.error(f"Error retrieving sample image: {e}")
68
+ return None, f"We encountered the following error trying to retrieve a sample image: {e}."
69
+
70
+ def get_sample_data(df, pred_taxon, rank):
71
+ '''
72
+ Function to randomly select a sample individual of the given taxon and provide associated native location.
73
+
74
+ Parameters:
75
+ -----------
76
+ df : DataFrame
77
+ DataFrame with all sample images listed and their filepaths (in "file_path" column).
78
+ pred_taxon : str
79
+ Predicted taxon of the uploaded image.
80
+ rank : int
81
+ Index of rank in RANKS chosen for prediction.
82
+
83
+ Returns:
84
+ --------
85
+ filepath : str
86
+ Filepath of selected sample image for predicted taxon.
87
+ eol_page_id : str
88
+ EOL page ID associated with predicted taxon for more information.
89
+ full_name : str
90
+ Full taxonomic name of the selected sample.
91
+ is_exact : bool
92
+ Flag indicating if the match is exact (i.e., with empty lower ranks).
93
+ '''
94
+ for idx in range(rank + 1):
95
+ taxon = RANKS[idx]
96
+ target_taxon = pred_taxon.split(" ")[idx]
97
+ df = df.filter(pl.col(taxon) == target_taxon)
98
+
99
+ if df.shape[0] == 0:
100
+ return None, np.nan, "", False
101
+
102
+ # First, try to find entries with empty lower ranks
103
+ exact_df = df
104
+ for lower_rank in RANKS[rank + 1:]:
105
+ exact_df = exact_df.filter((pl.col(lower_rank).is_null()) | (pl.col(lower_rank) == ""))
106
+
107
+ if exact_df.shape[0] > 0:
108
+ df_filtered = exact_df.sample()
109
+ full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0))
110
+ return df_filtered["file_path"][0], df_filtered["eol_page_id"].cast(pl.String)[0], full_name, True
111
+
112
+ # If no exact matches, return any entry with the specified rank
113
+ df_filtered = df.sample()
114
+ full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0)) + " " + " ".join(df_filtered.select(RANKS[rank+1:]).row(0))
115
+ return df_filtered["file_path"][0], df_filtered["eol_page_id"].cast(pl.String)[0], full_name, False
requirements.txt CHANGED
@@ -2,3 +2,6 @@ open_clip_torch
2
  torchvision
3
  torch
4
  gradio
 
 
 
 
2
  torchvision
3
  torch
4
  gradio
5
+ polars
6
+ pillow
7
+ boto3