File size: 2,738 Bytes
c45703e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import pandas as pd
import torch
import faiss
import gradio as gr

import base64

from PIL import Image
from io import BytesIO

from src.model import ConditionalViT, B16_Params, categories
from src.transform import valid_tf
from src.process_images import process_img, make_img_html
from src.examples import ExamplesHandler
from src.js_loader import JavaScriptLoader

# Load Model
# Instantiate ConditionalViT with the B16_Params keyword arguments and
# n_categories set to the number of entries in `categories`.
m = ConditionalViT(**B16_Params, n_categories=len(categories))
# The checkpoint is mapped onto the CPU (map_location="cpu") before its
# state dict is loaded into the model.
m.load_state_dict(torch.load("./artifacts/cat_condvit_b16.pth", map_location="cpu"))
# Evaluation mode: the model is only used for inference by retrieval().
m.eval()

# Load data
# Prebuilt FAISS index read from disk; queried by retrieval().
index = faiss.read_index("./artifacts/gallery_index.faiss")
# Parquet table handed to process_img() together with each retrieved index.
gal_imgs = pd.read_parquet("./artifacts/gallery_imgs.parquet")
# Transform built by valid_tf((224, 224)); applied to the query image before
# it is fed to the model.
tfs = valid_tf((224, 224))

# Number of neighbours requested from the index for each query.
K = 5

# Preset [image path, category name] pairs. The same list is given to
# gr.Examples and to ExamplesHandler when the interface is built below.
examples = [
    ["examples/3.jpg", "Outwear"],
    ["examples/3.jpg", "Lower Body"],
    ["examples/3.jpg", "Feet"],
    ["examples/757.jpg", "Bags"],
    ["examples/757.jpg", "Upper Body"],
    ["examples/769.jpg", "Upper Body"],
    ["examples/1811.jpg", "Lower Body"],
    ["examples/1811.jpg", "Bags"],
]

@torch.inference_mode()
def retrieval(image, category):
    """Build the HTML listing of the K gallery items nearest to a query.

    Args:
        image: Query image; the UI supplies it as a PIL image.
        category: Category selector, wrapped into a one-element tensor for
            the model (presumably an integer index into ``categories`` --
            confirm against the dropdown configuration).

    Returns:
        The per-item HTML snippets joined by newlines and followed by an
        empty paragraph, or ``None`` when either input is missing.
    """
    if image is None or category is None:
        return None

    # Batch of one: the transformed image plus its category.
    batch = tfs(image).unsqueeze(0)
    condition = torch.tensor([category])
    query_embedding = m(batch, condition)

    # Element 1 of the search result is taken as the neighbour ids (per the
    # FAISS search API); row 0 belongs to the single query.
    neighbour_ids = index.search(query_embedding, K)[1][0]

    gallery_items = [process_img(hit, gal_imgs) for hit in neighbour_ids]
    fragments = list(map(make_img_html, gallery_items))
    # Trailing empty paragraph: avoids Gradio's
    # last-child{margin-bottom:0!important;} rule hitting the last result.
    fragments.append("<p></p>")

    return "\n".join(fragments)



# Instantiated only for its side effects (the result is discarded).
# NOTE(review): presumably registers src/custom_functions.js with the served
# page -- confirm in src/js_loader.py.
JavaScriptLoader("src/custom_functions.js")
# Page layout, top to bottom: header text, query inputs, preset examples,
# then the results area.
with gr.Blocks(css="src/style.css") as demo:
    with gr.Column():
        gr.Markdown("""
        # Conditional ViT Demo 
        [[`Paper`](https://arxiv.org/abs/2306.02928)] 
        [[`Code`](https://github.com/Simon-Lepage/CondViT-LRVSF)] 
        [[`Dataset`](https://huggingface.co/datasets/Slep/LAION-RVS-Fashion)] 

        *Running on 2 vCPU, 16Go RAM.*

        - **Model :** Categorical CondViT-B/16
        - **Gallery :** 93K images.
        """)

        # Input section
        with gr.Row():
            # type="pil": retrieval() receives the query as a PIL image.
            img = gr.Image(label="Query Image", type="pil", elem_id="query_img")
            with gr.Column():
                # type='index': per Gradio's Dropdown docs the callback gets the
                # position of the selected choice, which retrieval() wraps in
                # torch.tensor([category]).
                cat = gr.Dropdown(choices = categories, label="Category", value="Upper Body", type='index', elem_id="dropdown")
                submit = gr.Button("Submit")
        
        # Examples
        # Built-in examples widget, bound to the two inputs above.
        gr.Examples(examples, inputs=[img, cat], fn=retrieval, elem_id = "preset_examples", examples_per_page=100)
        # Custom HTML rendering of the same examples list, produced by
        # ExamplesHandler.to_html().
        gr.HTML(value=ExamplesHandler(examples).to_html(), label = "examples", elem_id = "html_examples")
        
        # Outputs
        gr.Markdown("# Retrieved Items")
        out = gr.HTML(label="Results", elem_id = "html_output")

        # Clicking Submit runs retrieval(img, cat) and writes the returned HTML
        # into the results component.
        submit.click(fn=retrieval, inputs=[img, cat], outputs=out)

# Runs at import time: there is no __main__ guard around the launch.
demo.launch()