ariG23498 HF staff commited on
Commit
0de6eb2
·
verified ·
1 Parent(s): 96bdaa7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -17
app.py CHANGED
@@ -1,31 +1,26 @@
1
- from huggingface_hub import list_models
2
  import gradio as gr
3
  from transformers import pipeline
4
 
5
- # Fetch timm models from Hugging Face Hub
6
- timm_models = list_models(filter="timm", sort="downloads", limit=20) # Fetch top 20 based on downloads
7
- model_ids = [model.modelId for model in timm_models]
8
-
9
- # Initialize a pipeline with a default model
10
- default_model = model_ids[0]
11
- pipe = pipeline("image-classification", model=default_model)
12
-
13
- # Function for classification
14
  def classify(image, model_name):
15
- pipe.model = model_name # Update model dynamically
16
- results = pipe(image)
17
- return {result["label"]: round(result["score"], 2) for result in results}
 
 
 
 
18
 
19
  # Gradio Interface
20
  demo = gr.Interface(
21
  fn=classify,
22
  inputs=[
23
  gr.Image(type="pil", label="Upload an Image"),
24
- gr.Dropdown(choices=model_ids, label="Select timm Model", value=default_model)
25
  ],
26
  outputs=gr.Label(num_top_classes=3, label="Top Predictions"),
27
- title="timm Model Image Classifier",
28
- description="Select a timm model and upload an image for classification."
29
  )
30
 
31
- demo.launch(debug=True)
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
+ # Function for image classification
 
 
 
 
 
 
 
 
5
  def classify(image, model_name):
6
+ try:
7
+ pipe = pipeline("image-classification", model=model_name)
8
+ results = pipe(image)
9
+ return {result["label"]: round(result["score"], 2) for result in results}
10
+ except Exception as e:
11
+ # Handle errors gracefully, e.g., invalid model names
12
+ return {"Error": str(e)}
13
 
14
  # Gradio Interface
15
  demo = gr.Interface(
16
  fn=classify,
17
  inputs=[
18
  gr.Image(type="pil", label="Upload an Image"),
19
+ gr.Textbox(label="Enter timm Model Name", placeholder="e.g., timm/mobilenetv3_large_100.ra_in1k"),
20
  ],
21
  outputs=gr.Label(num_top_classes=3, label="Top Predictions"),
22
+ title="Custom timm Model Image Classifier",
23
+ description="Enter a timm model name from Hugging Face, upload an image, and get predictions.",
24
  )
25
 
26
+ demo.launch()