import gradio as gr
from PIL import Image
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
import tensorflow as tf
from PIL import ImageDraw

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image segmentation model
feature_extractor = SegformerFeatureExtractor.from_pretrained(
    "nvidia/segformer-b1-finetuned-cityscapes-1024-1024"
)
model_segmentation = TFSegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b1-finetuned-cityscapes-1024-1024"
)

# Image detection model (currently disabled)
# processor_detection = DetrImageProcessor.from_pretrained(
#     "facebook/detr-resnet-50", revision="no_timm"
# )
# model_detection = DetrForObjectDetection.from_pretrained(
#     "facebook/detr-resnet-50", revision="no_timm"
# )


def ade_palette():
    """ADE20K palette: maps each class index to an RGB value."""
    return [
        [204, 87, 92],
        [112, 185, 212],
        [45, 189, 106],
        [234, 123, 67],
        [78, 56, 123],
        [210, 32, 89],
        [90, 180, 56],
        [155, 102, 200],
        [33, 147, 176],
        [255, 183, 76],
        [67, 123, 89],
        [190, 60, 45],
        [134, 112, 200],
        [56, 45, 189],
        [200, 56, 123],
        [87, 92, 204],
        [120, 56, 123],
        [45, 78, 123],
        [45, 123, 67],
    ]


labels_list = []
with open("labels.txt", "r") as fp:
    for line in fp:
        labels_list.append(line.rstrip("\n"))

colormap = np.asarray(ade_palette())


def label_to_color_image(label):
    """Converts a label map to a color image."""
    if label.ndim != 2:
        raise ValueError("Expected a 2-D input label.")
    if np.max(label) >= len(colormap):
        raise ValueError("Label value is too large.")
    return colormap[label]


def draw_plot(pred_img, seg):
    """Plots the blended image next to a legend of the classes present."""
    fig = plt.figure(figsize=(20, 15))
    grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(pred_img)
    plt.axis("off")

    LABEL_NAMES = np.asarray(labels_list)
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

    unique_labels = np.unique(seg.numpy().astype("uint8"))
    ax = plt.subplot(grid_spec[1])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0, labelsize=25)
    return fig


def sepia(inputs, button_text):
    """Runs semantic segmentation (object detection is currently disabled) and returns the plot."""
    input_img = Image.fromarray(inputs)

    inputs_segmentation = feature_extractor(images=input_img, return_tensors="tf")
    outputs_segmentation = model_segmentation(**inputs_segmentation)
    logits_segmentation = outputs_segmentation.logits

    # Rearrange to (batch, height, width, classes) and resize back to the input size.
    logits_segmentation = tf.transpose(logits_segmentation, [0, 2, 3, 1])
    logits_segmentation = tf.image.resize(logits_segmentation, input_img.size[::-1])
    seg = tf.math.argmax(logits_segmentation, axis=-1)[0]

    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(colormap):
        color_seg[seg.numpy() == label, :] = color

    # Blend the original image with the color-coded segmentation map.
    pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
    pred_img = pred_img.astype(np.uint8)

    fig = draw_plot(pred_img, seg)
    return fig


def on_button_click(inputs):
    """Button click event handler."""
    image_path, selected_option = inputs
    if selected_option == "dropout":
        # For 'dropout', pick one of the available options at random
        # (only "segmentation" remains while detection is disabled).
        selected_option = np.random.choice(["segmentation"])
    return sepia(image_path, selected_option)


# Use gr.Dropdown so the user can select an option.
dropdown = gr.Dropdown(
    ["segmentation"], label="Menu", info="Choose Segmentation!"
)
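
# A minimal headless sketch of calling the pipeline without the Gradio UI,
# assuming "01.jpg" from the examples list below exists in the working
# directory (the output filename is an arbitrary choice; uncomment to run):
#
#   img = np.array(Image.open("01.jpg").convert("RGB"))
#   fig = sepia(img, "segmentation")
#   fig.savefig("segmentation_result.png")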
demo = gr.Interface(
    fn=sepia,
    inputs=[gr.Image(shape=(400, 600)), dropdown],
    outputs=["plot"],
    examples=[
        ["01.jpg", "segmentation"],
        ["02.jpeg", "segmentation"],
        ["03.jpeg", "segmentation"],
        ["04.jpeg", "segmentation"],
    ],
    allow_flagging="never",
)

demo.launch()
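
# If the DETR model commented out at the top of this file is re-enabled, a
# sketch of the detection step could look like the following. `detect` is a
# hypothetical helper, the 0.9 threshold is an assumption rather than part of
# the original code, and model_detection is assumed to have been moved to
# `device` already:
#
#   def detect(input_img):
#       inputs = processor_detection(images=input_img, return_tensors="pt").to(device)
#       outputs = model_detection(**inputs)
#       target_sizes = torch.tensor([input_img.size[::-1]])
#       return processor_detection.post_process_object_detection(
#           outputs, target_sizes=target_sizes, threshold=0.9
#       )[0]  # dict with "scores", "labels", and "boxes"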