"""Gradio demo: category-conditioned image retrieval with a CondViT-B/16 over a FAISS gallery."""

import pandas as pd
import torch
import faiss
import gradio as gr
import base64
from PIL import Image
from io import BytesIO

from src.model import ConditionalViT, B16_Params, categories
from src.transform import valid_tf
from src.process_images import process_img, make_img_html
from src.examples import ExamplesHandler
from src.js_loader import JavaScriptLoader

# Load Model
m = ConditionalViT(**B16_Params, n_categories=len(categories))
# torch.load unpickles the checkpoint, so only point it at trusted artifacts.
m.load_state_dict(torch.load("./artifacts/cat_condvit_b16.pth", map_location="cpu"))
m.eval()

# Load data
index = faiss.read_index("./artifacts/gallery_index.faiss")
gal_imgs = pd.read_parquet("./artifacts/gallery_imgs.parquet")

tfs = valid_tf((224, 224))
K = 5  # number of gallery neighbours retrieved per query

# (image path, category name) pairs shown as clickable presets.
examples = [
    ["examples/3.jpg", "Outwear"],
    ["examples/3.jpg", "Lower Body"],
    ["examples/3.jpg", "Feet"],
    ["examples/757.jpg", "Bags"],
    ["examples/757.jpg", "Upper Body"],
    ["examples/769.jpg", "Upper Body"],
    ["examples/1811.jpg", "Lower Body"],
    ["examples/1811.jpg", "Bags"],
]


@torch.inference_mode()
def retrieval(image, category):
    """Return an HTML snippet showing the K nearest gallery items for a query.

    Args:
        image: Query image as a PIL image (the input component uses
            ``type="pil"``), or ``None`` when nothing has been uploaded.
        category: Index into ``categories`` (the dropdown uses
            ``type="index"``), or ``None`` when no category is selected.

    Returns:
        The per-item HTML fragments joined with newlines, or ``None`` when
        either input is missing.
    """
    if image is None or category is None:
        return None

    # Force 3 channels so RGBA / grayscale uploads do not reach the ViT with
    # the wrong channel count (harmless if valid_tf already converts -- TODO confirm).
    q_emb = m(tfs(image.convert("RGB")).unsqueeze(0), torch.tensor([category]))

    # FAISS's Python API operates on NumPy arrays; convert explicitly instead of
    # relying on implicit array coercion of the torch tensor.
    _, ids = index.search(q_emb.numpy(), K)

    # FAISS pads missing neighbours with -1; never hand such ids to process_img.
    imgs = [process_img(idx, gal_imgs) for idx in ids[0] if idx >= 0]
    html = [make_img_html(i) for i in imgs]
    # Trailing empty element so the last real result is not the one matched by
    # Gradio's last-child{margin-bottom:0!important;} rule.
    # NOTE(review): the original markup of this literal was lost when the file
    # was flattened; an empty <div> is assumed -- confirm.
    html += ["<div></div>"]
    return "\n".join(html)


JavaScriptLoader("src/custom_functions.js")

DESCRIPTION = """
# Conditional ViT Demo
[[`Paper`](https://arxiv.org/abs/2306.02928)]
[[`Code`](https://github.com/Simon-Lepage/CondViT-LRVSF)]
[[`Dataset`](https://huggingface.co/datasets/Slep/LAION-RVS-Fashion)]

*Running on 2 vCPU, 16Go RAM.*

- **Model :** Categorical CondViT-B/16
- **Gallery :** 93K images.
"""

with gr.Blocks(css="src/style.css") as demo:
    with gr.Column():
        gr.Markdown(DESCRIPTION)

        # Input section
        with gr.Row():
            img = gr.Image(label="Query Image", type="pil", elem_id="query_img")
            with gr.Column():
                cat = gr.Dropdown(
                    choices=categories,
                    label="Category",
                    value="Upper Body",
                    type="index",  # retrieval() receives the category index
                    elem_id="dropdown",
                )
                submit = gr.Button("Submit")

        # Examples
        gr.Examples(
            examples,
            inputs=[img, cat],
            fn=retrieval,
            elem_id="preset_examples",
            examples_per_page=100,
        )
        gr.HTML(
            value=ExamplesHandler(examples).to_html(),
            label="examples",
            elem_id="html_examples",
        )

        # Outputs
        gr.Markdown("# Retrieved Items")
        out = gr.HTML(label="Results", elem_id="html_output")

    submit.click(fn=retrieval, inputs=[img, cat], outputs=out)

demo.launch()