ravi.naik commited on
Commit
f94c291
1 Parent(s): 4822531

Fixed sample image path issues

Browse files
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import random
 
3
  import numpy as np
4
  from PIL import Image
5
  import torch
@@ -39,16 +40,23 @@ def read_image(path):
39
  return data
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def sample_images():
43
- images = []
44
- length = len(datamodule.test_dataset)
45
- classes = datamodule.train_dataset.classes
46
- for i in range(10):
47
- idx = random.randint(0, length - 1)
48
- image, label = datamodule.test_dataset[idx]
49
- image = inv_normalize(image).permute(1, 2, 0).numpy()
50
- images.append((image, classes[label]))
51
- return images
52
 
53
 
54
  def get_misclassified_images(misclassified_count):
 
1
  import gradio as gr
2
  import random
3
+ import pathlib
4
  import numpy as np
5
  from PIL import Image
6
  import torch
 
40
  return data
41
 
42
 
43
+ # def sample_images():
44
+ # images = []
45
+ # length = len(datamodule.test_dataset)
46
+ # classes = datamodule.train_dataset.classes
47
+ # for i in range(10):
48
+ # idx = random.randint(0, length - 1)
49
+ # image, label = datamodule.test_dataset[idx]
50
+ # image = inv_normalize(image).permute(1, 2, 0).numpy()
51
+ # images.append((image, classes[label]))
52
+ # return images
53
+
54
+
55
  def sample_images():
56
+ sample_imges_dir = pathlib.Path("./sample_images")
57
+ sample_images = list(sample_imges_dir.iterdir())
58
+ sample_image_labels = [image.stem for image in sample_images]
59
+ return list(zip(sample_images, sample_image_labels))
 
 
 
 
 
60
 
61
 
62
  def get_misclassified_images(misclassified_count):
sample_images/airplane.png ADDED
sample_images/automobile.png ADDED
sample_images/bird.png ADDED
sample_images/cat.png ADDED
sample_images/deer.png ADDED
sample_images/dog.png ADDED
sample_images/frog.png ADDED
sample_images/horse.png ADDED
sample_images/ship.png ADDED
sample_images/truck.png ADDED