Spaces:
Runtime error
Runtime error
import torch | |
import clip | |
from PIL import Image | |
import glob | |
import os | |
from random import choice | |
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load returns the network together with the callable used below to turn
# a PIL image into a model input tensor.
model, preprocess = clip.load("ViT-L/14@336px", device=device)

# Paths of every entry directly under ./images, resolved against the current
# working directory. load_random_image() draws from this list.
COCO = glob.glob(os.path.join(os.getcwd(), "images", "*"))

# Model identifiers; not referenced by the rest of the visible code.
# NOTE(review): presumably the names accepted by clip.load -- confirm.
available_models = [
    'RN50',
    'RN101',
    'RN50x4',
    'RN50x16',
    'RN50x64',
    'ViT-B/32',
    'ViT-B/16',
    'ViT-L/14',
    'ViT-L/14@336px',
]
def load_random_image():
    """Open and return one randomly chosen image from ``COCO``.

    Returns:
        The object returned by ``Image.open`` for the selected path.

    Raises:
        IndexError: if ``COCO`` is empty (raised by ``random.choice``).
    """
    return Image.open(choice(COCO))
def next_image():
    """Replace the current target image with a new random one.

    Updates the module-level ``image_org`` (the opened image), ``image``
    (the preprocessed, batched tensor on ``device``) and ``image_features``
    (the encoded image that ``answer`` scores guesses against).
    """
    global image_org, image, image_features
    image_org = load_random_image()
    # load_random_image() already returns a PIL image, so it is handed to
    # preprocess directly -- the same way the module-level initialisation and
    # reset_everything() do. Image.fromarray is meant for array input.
    image = preprocess(image_org).unsqueeze(0).to(device)
    # answer() reads image_features, not image; without re-encoding here the
    # game would keep scoring guesses against the previous picture.
    with torch.no_grad():
        image_features = model.encode_image(image)
# def calculate_logits(image, text): | |
# return model(image, text)[0] | |
def calculate_logits(image_features, text_features):
    """Return the scaled similarity between image and text feature rows.

    Each input is divided by its norm along ``dim=1``; the normalised image
    features are multiplied by ``model.logit_scale.exp()`` and then
    matrix-multiplied with the transposed normalised text features.

    Args:
        image_features: tensor of image embeddings.
        text_features: tensor of text embeddings.
            (assumes both are 2-D, one embedding per row -- TODO confirm)

    Returns:
        Tensor of similarity scores, one row per image embedding and one
        column per text embedding.
    """
    unit_image = image_features / image_features.norm(dim=1, keepdim=True)
    unit_text = text_features / text_features.norm(dim=1, keepdim=True)
    scale = model.logit_scale.exp()
    # Same evaluation order as before: scale the image side first, then take
    # the matrix product with the text side.
    return scale * unit_image @ unit_text.t()
# Game state shared with answer() and reset_everything().
# last: score of the previous guess; -1 is treated by answer() as "no guess yet".
# best: highest score seen so far.
# goal: threshold that answer() compares a score against.
last = best = -1
goal = 23

# Pick the initial target image and encode it once, so that each guess only
# needs to encode its text.
image_org = load_random_image()
image = preprocess(image_org).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = model.encode_image(image)
def answer(message):
    """Score a text guess against the current image and classify the result.

    Side effects: updates the module-level ``last`` (always) and ``best``
    (when the new score exceeds it).

    Args:
        message: the guess text; tokenised with ``clip.tokenize``.

    Returns:
        A ``(score, status)`` pair. ``score`` is the single value produced by
        ``calculate_logits`` for this guess. ``status`` is:
            3  -- the score is a new best;
            0  -- the score is lower than the previous guess;
            1  -- the score is higher than the previous guess;
            2  -- the score equals the previous guess and exceeds ``goal``;
            -1 -- first guess, or the score equals the previous guess
                  without exceeding ``goal``.
    """
    global last, best

    tokens = clip.tokenize([message]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
        # Kept inside no_grad so the result can be converted with .numpy()
        # (model.logit_scale is presumably a learnable parameter -- confirm).
        similarity = calculate_logits(image_features, text_features)
        score = similarity.cpu().numpy().flatten()[0]

    previous = last
    last = score

    # A new best takes precedence over every other classification.
    if score > best:
        best = score
        return score, 3

    if previous == -1:
        status = -1
    elif previous > score:
        status = 0
    elif previous < score:
        status = 1
    elif score > goal:
        # NOTE(review): only reachable when score == previous, so reaching the
        # goal is almost never reported -- confirm the intended ordering.
        status = 2
    else:
        status = -1
    return score, status
def reset_everything():
    """Reset the game state and switch to a new random target image.

    Restores ``last``, ``best`` and ``goal`` to their initial values, loads a
    new image into ``image_org`` / ``image``, and re-encodes it into
    ``image_features`` so later ``answer`` calls score against the new image.
    """
    global last, best, goal, image, image_org, image_features
    last = -1
    best = -1
    goal = 23
    image_org = load_random_image()
    image = preprocess(image_org).unsqueeze(0).to(device)
    # answer() scores against image_features, not image; refresh it here,
    # otherwise guesses after a reset are still compared to the old picture.
    with torch.no_grad():
        image_features = model.encode_image(image)