import PIL
import torch
import gradio as gr

from process import load_seg_model, get_palette, generate_mask


# Run everything on CPU by default.
device = 'cpu'
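
# A minimal sketch for optional GPU support, assuming load_seg_model and
# generate_mask in process.py accept a torch device string unchanged:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'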


def initialize_and_load_models():
    # Load the pretrained cloth-segmentation network once at startup.
    checkpoint_path = 'model/cloth_segm.pth'
    net = load_seg_model(checkpoint_path, device=device)
    return net


net = initialize_and_load_models()

# Color palette for the 4 segmentation classes (background plus 3 garment
# categories).
palette = get_palette(4)


def run(img):
    # Segment the input image and return the colorized cloth mask.
    cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
    return cloth_seg
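
# Quick local check bypassing the Gradio UI, assuming generate_mask returns a
# PIL image ('input/sample.jpg' and the output path are hypothetical
# placeholders):
# mask = run(PIL.Image.open('input/sample.jpg'))
# mask.save('output/sample_mask.png')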


# gr.inputs / gr.outputs were deprecated and later removed from Gradio;
# components are now created directly from the top-level namespace.
input_image = gr.Image(label="Input Image", type="pil")
cloth_seg_image = gr.Image(label="Cloth Segmentation", type="pil")

title = "Demo for Cloth Segmentation"
description = "Upload an image of a person and get back a colorized mask of the detected garments."
inputs = [input_image]
outputs = [cloth_seg_image]

# share=True publishes a temporary public URL; drop it to serve only on localhost.
gr.Interface(fn=run, inputs=inputs, outputs=outputs, title=title, description=description).launch(share=True)