miccull commited on
Commit
4602ab6
·
1 Parent(s): e0ce7b8

initial commit

Browse files
Files changed (2) hide show
  1. app.py +114 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ import torchvision
6
+ import clip
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ import gradio as gr
10
+
11
+
12
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
13
+
14
+ model_name = 'ViT-B/16' #@param ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
15
+ model, preprocess = clip.load(model_name)
16
+
17
+ model.to(DEVICE).eval()
18
+ resolution = model.visual.input_resolution
19
+ resizer = torchvision.transforms.Resize(size=(resolution, resolution))
20
+
21
+
22
+ def create_rgb_tensor(color):
23
+ """color is e.g. [1,0,0]"""
24
+ return torch.tensor(color, device=DEVICE).reshape((1, 3, 1, 1))
25
+
26
+ def encode_color(color):
27
+ """color is e.g. [1,0,0]"""
28
+ rgb = create_rgb_tensor(color)
29
+ return model.encode_image( resizer(rgb) )
30
+
31
+ def encode_text(text):
32
+ tokenized_text = clip.tokenize(text).cuda()
33
+ return model.encode_text(tokenized_text)
34
+
35
+ def lerp(x, y, steps=11):
36
+ """Linear interpolation between two tensors """
37
+
38
+ weights = torch.tensor(np.linspace(0,1,steps), device=DEVICE).reshape([-1, 1, 1, 1])
39
+
40
+ interpolated = x * (1 - weights) + y * weights
41
+
42
+ return interpolated
43
+
44
+ def get_interpolated_scores(x, y, encoded_text, steps=11):
45
+ interpolated = lerp(x, y, steps)
46
+ interpolated_encodings = model.encode_image(resizer(interpolated))
47
+
48
+ scores = torch.cosine_similarity(interpolated_encodings, encoded_text).detach().cpu().numpy()
49
+
50
+ rgb = interpolated.detach().cpu().numpy().reshape(-1, 3)
51
+ interpolated_hex = [rgb2hex(x) for x in rgb]
52
+
53
+ data = pd.DataFrame({
54
+ 'similarity': scores,
55
+ 'color': interpolated_hex
56
+ }).reset_index().rename(columns={'index':'step'})
57
+
58
+ return data
59
+
60
+ def rgb2hex(rgb):
61
+ rgb = (rgb * 255).astype(int)
62
+ r,g,b = rgb
63
+ return "#{:02x}{:02x}{:02x}".format(r,g,b)
64
+
65
+
66
+ def similarity_plot(data, text_prompt):
67
+ title = f'CLIP Cosine Similarity Prompt="{text_prompt}"'
68
+
69
+ fig, ax = plt.subplots()
70
+ plot = data['similarity'].plot(kind='bar',
71
+ ax=ax,
72
+ stacked=True,
73
+ title=title,
74
+ color=data['color'],
75
+ width=1.0,
76
+ xlim=(0, 2),
77
+ grid=False)
78
+
79
+
80
+ plot.get_xaxis().set_visible(False) ;
81
+ return fig
82
+
83
+
84
+
85
+ def interpolation_experiment(rgb_start, rgb_end, text_prompt, steps=11):
86
+
87
+ start = create_rgb_tensor(rgb_start)
88
+ end = create_rgb_tensor(rgb_end)
89
+ encoded_text = encode_text(text_prompt)
90
+
91
+ data = get_interpolated_scores(start, end, encoded_text, steps)
92
+ return similarity_plot(data, text_prompt)
93
+
94
+
95
+
96
+
97
+ start_input = gr.inputs.Textbox(lines=1, default="1, 0, 0", label="Start RGB")
98
+ end_input = gr.inputs.Textbox(lines=1, default="0, 1, 0", label="End RGB")
99
+ ' (Comma separated numbers between 0 and 1)'
100
+
101
+ text_input = gr.inputs.Textbox(lines=1, label="Text Prompt", default='A solid red square')
102
+
103
+ steps_input = gr.inputs.Slider(minimum=1, maximum=100, step=1, default=11, label="Interpolation Steps")
104
+
105
+ def gradio_fn(rgb_start, rgb_end, text_prompt, steps=11):
106
+
107
+ rgb_start = [float(x.strip()) for x in rgb_start.split(',')]
108
+ rgb_end = [float(x.strip()) for x in rgb_end.split(',')]
109
+ out = interpolation_experiment(rgb_start, rgb_end, text_prompt, steps)
110
+
111
+ return out
112
+
113
+ iface = gr.Interface( fn=gradio_fn, inputs=[start_input, end_input, text_input, steps_input], outputs="plot")
114
+ iface.launch(debug=True, share=False)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ comet_ml
2
+ ftfy
3
+ regex
4
+ git+https://github.com/openai/CLIP.git
5
+ pandas
6
+ Pillow
7
+ tqdm
8
+ torch
9
+ torchvision
10
+ matplotlib
11
+ seaborn