seojun commited on
Commit
fcdbac9
1 Parent(s): 04b2c5f

FEAT : 작업3 완료

Browse files
Files changed (8) hide show
  1. 01.jpg +0 -0
  2. 02.jpeg +0 -0
  3. 03.jpeg +0 -0
  4. 04.jpeg +0 -0
  5. README.md +4 -4
  6. app.py +182 -0
  7. labels.txt +19 -0
  8. requirements.txt +7 -0
01.jpg ADDED
02.jpeg ADDED
03.jpeg ADDED
04.jpeg ADDED
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Segmentation3
3
- emoji: 📚
4
- colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.2.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: Segment3
3
+ emoji: 🌖
4
+ colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.44.4
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from matplotlib import gridspec
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
7
+ from transformers import DetrImageProcessor, DetrForObjectDetection
8
+ import torch
9
+ import tensorflow as tf
10
+ from PIL import ImageDraw
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ # image segmentation 모델
15
+ feature_extractor = SegformerFeatureExtractor.from_pretrained(
16
+ "nvidia/segformer-b1-finetuned-cityscapes-1024-1024"
17
+ )
18
+ model_segmentation = TFSegformerForSemanticSegmentation.from_pretrained(
19
+ "nvidia/segformer-b1-finetuned-cityscapes-1024-1024"
20
+ )
21
+
22
+ # image detection 모델
23
+ processor_detection = DetrImageProcessor.from_pretrained(
24
+ "facebook/detr-resnet-50", revision="no_timm"
25
+ )
26
+ model_detection = DetrForObjectDetection.from_pretrained(
27
+ "facebook/detr-resnet-50", revision="no_timm"
28
+ )
29
+
30
+
31
+ def ade_palette():
32
+ """ADE20K 팔레트: 각 클래스를 RGB 값에 매핑해주는 함수입니다."""
33
+
34
+ return [
35
+ [204, 87, 92],
36
+ [112, 185, 212],
37
+ [45, 189, 106],
38
+ [234, 123, 67],
39
+ [78, 56, 123],
40
+ [210, 32, 89],
41
+ [90, 180, 56],
42
+ [155, 102, 200],
43
+ [33, 147, 176],
44
+ [255, 183, 76],
45
+ [67, 123, 89],
46
+ [190, 60, 45],
47
+ [134, 112, 200],
48
+ [56, 45, 189],
49
+ [200, 56, 123],
50
+ [87, 92, 204],
51
+ [120, 56, 123],
52
+ [45, 78, 123],
53
+ [45, 123, 67],
54
+ ]
55
+
56
+
57
+ labels_list = []
58
+
59
+ with open(r"labels.txt", "r") as fp:
60
+ for line in fp:
61
+ labels_list.append(line[:-1])
62
+
63
+ colormap = np.asarray(ade_palette())
64
+
65
+
66
+ def label_to_color_image(label):
67
+ """라벨을 컬러 이미지로 변환해주는 함수입니다."""
68
+
69
+ if label.ndim != 2:
70
+ raise ValueError("2차원 입력 라벨을 기대합니다.")
71
+
72
+ if np.max(label) >= len(colormap):
73
+ raise ValueError("라벨 값이 너무 큽니다.")
74
+ return colormap[label]
75
+
76
+
77
+ def draw_plot(pred_img, seg):
78
+ """이미지와 세그멘테이션 결과를 floating 하는 함수입니다."""
79
+
80
+ fig = plt.figure(figsize=(20, 15))
81
+ grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
82
+
83
+ plt.subplot(grid_spec[0])
84
+ plt.imshow(pred_img)
85
+ plt.axis("off")
86
+ LABEL_NAMES = np.asarray(labels_list)
87
+ FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
88
+ FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
89
+
90
+ unique_labels = np.unique(seg.numpy().astype("uint8"))
91
+ ax = plt.subplot(grid_spec[1])
92
+ plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
93
+ ax.yaxis.tick_right()
94
+ plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
95
+ plt.xticks([], [])
96
+ ax.tick_params(width=0.0, labelsize=25)
97
+
98
+ return fig
99
+
100
+
101
+ def sepia(inputs, button_text):
102
+ """객체 검출 또는 세그멘테이션을 수행하고 결과를 반환하는 함수입니다."""
103
+
104
+ input_img = Image.fromarray(inputs)
105
+ if button_text == "detection":
106
+ inputs_detection = processor_detection(images=input_img, return_tensors="pt")
107
+ outputs_detection = model_detection(**inputs_detection)
108
+
109
+ target_sizes = torch.tensor([input_img.size[::-1]])
110
+ results_detection = processor_detection.post_process_object_detection(
111
+ outputs_detection, target_sizes=target_sizes, threshold=0.9
112
+ )[0]
113
+
114
+ draw = ImageDraw.Draw(input_img)
115
+ for score, label, box in zip(
116
+ results_detection["scores"],
117
+ results_detection["labels"],
118
+ results_detection["boxes"],
119
+ ):
120
+ box = [round(i, 2) for i in box.tolist()]
121
+ label_name = model_detection.config.id2label[label.item()]
122
+ print(
123
+ f"Detected {label_name} with confidence "
124
+ f"{round(score.item(), 3)} at location {box}"
125
+ )
126
+ draw.rectangle(box, outline="red", width=3)
127
+ draw.text((box[0], box[1]), label_name, fill="red", font=None)
128
+
129
+ fig = plt.figure(figsize=(20, 15))
130
+ plt.imshow(input_img)
131
+ plt.axis("off")
132
+ return fig
133
+
134
+ elif button_text == "segmentation":
135
+ inputs_segmentation = feature_extractor(images=input_img, return_tensors="tf")
136
+ outputs_segmentation = model_segmentation(**inputs_segmentation)
137
+ logits_segmentation = outputs_segmentation.logits
138
+
139
+ logits_segmentation = tf.transpose(logits_segmentation, [0, 2, 3, 1])
140
+ logits_segmentation = tf.image.resize(logits_segmentation, input_img.size[::-1])
141
+ seg = tf.math.argmax(logits_segmentation, axis=-1)[0]
142
+
143
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
144
+ for label, color in enumerate(colormap):
145
+ color_seg[seg.numpy() == label, :] = color
146
+
147
+ pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
148
+ pred_img = pred_img.astype(np.uint8)
149
+
150
+ fig = draw_plot(pred_img, seg)
151
+ return fig
152
+
153
+ return "Please select 'detection' or 'segmentation'."
154
+
155
+ def on_button_click(inputs):
156
+ """버튼 클릭 이벤트 핸들러"""
157
+ image_path, selected_option = inputs
158
+ if selected_option == "dropout":
159
+ # 'dropout'이면 두 가지 중에 하나를 랜덤으로 선택
160
+ selected_option = np.random.choice(["detection", "segmentation"])
161
+
162
+ return sepia(image_path, selected_option)
163
+
164
+ # Gr.Dropdown을 사용하여 옵션을 선택할 수 있도록 변경
165
+ dropdown = gr.Dropdown(
166
+ ["detection", "segmentation"], label="Menu", info="Select One!"
167
+ )
168
+
169
+ demo = gr.Interface(
170
+ fn=sepia,
171
+ inputs=[gr.Image(shape=(400, 600)), dropdown],
172
+ outputs=["plot"],
173
+ examples=[
174
+ ["01.jpg", "Click me"],
175
+ ["02.jpeg", "Click me"],
176
+ ["03.jpeg", "Click me"],
177
+ ["04.jpeg", "Click me"],
178
+ ],
179
+ allow_flagging="never",
180
+ )
181
+
182
+ demo.launch()
labels.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ road
2
+ sidewalk
3
+ building
4
+ wall
5
+ fence
6
+ pole
7
+ traffic light
8
+ traffic sign
9
+ vegetation
10
+ terrain
11
+ sky
12
+ person
13
+ rider
14
+ car
15
+ truck
16
+ bus
17
+ train
18
+ motorcycle
19
+ bicycle
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ tensorflow==2.13.0
4
+ numpy
5
+ Image
6
+ matplotlib
7
+ Pillow