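"""Batch inference and evaluation for RAM / Tag2Text image tagging.

Runs the chosen model over an OpenImages-based benchmark, caches the
per-class logits, thresholds them into tag predictions, and reports
mAP, CP and CR.

Example invocation (script name and checkpoint path are illustrative):

    python batch_inference.py \
        --model-type ram \
        --checkpoint pretrained/ram_swin_large_14m.pth \
        --dataset openimages_common_214
"""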
from argparse import ArgumentParser
from pathlib import Path
from typing import Dict, List, Optional, TextIO, Tuple

import torch
from PIL import Image, UnidentifiedImageError
from torch import Tensor
from torch.nn import Module, Parameter
from torch.nn.functional import relu, sigmoid
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from ram import get_transform
from ram.models import ram, tag2text
from ram.utils import build_openset_label_embedding, get_mAP, get_PR

device = "cuda" if torch.cuda.is_available() else "cpu"

def parse_args():
    parser = ArgumentParser()
    # model
    parser.add_argument("--model-type",
                        type=str,
                        choices=("ram", "tag2text"),
                        required=True)
    parser.add_argument("--checkpoint",
                        type=str,
                        required=True)
    parser.add_argument("--backbone",
                        type=str,
                        choices=("swin_l", "swin_b"),
                        default=None,
                        help="If `None`, inferred from `--model-type`")
    parser.add_argument("--open-set",
                        action="store_true",
                        help=(
                            "Treat all categories in the taglist file as "
                            "unseen and perform open-set classification. Only "
                            "works with RAM."
                        ))
    # data
    parser.add_argument("--dataset",
                        type=str,
                        choices=(
                            "openimages_common_214",
                            "openimages_rare_200"
                        ),
                        required=True)
    parser.add_argument("--input-size",
                        type=int,
                        default=384)
    # threshold
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--threshold",
                       type=float,
                       default=None,
                       help=(
                           "Use a custom threshold for all classes. Mutually "
                           "exclusive with `--threshold-file`. If both "
                           "`--threshold` and `--threshold-file` are `None`, "
                           "the default threshold setting is used."
                       ))
    group.add_argument("--threshold-file",
                       type=str,
                       default=None,
                       help=(
                           "Use custom class-wise thresholds by providing a "
                           "text file. Each line is a float threshold, "
                           "following the order of the tags in the taglist "
                           "file. See `ram/data/ram_tag_list_threshold.txt` "
                           "for an example. Mutually exclusive with "
                           "`--threshold`. If both `--threshold` and "
                           "`--threshold-file` are `None`, the default "
                           "threshold setting is used."
                       ))
    # miscellaneous
    parser.add_argument("--output-dir", type=str, default="./outputs")
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--num-workers", type=int, default=4)

    args = parser.parse_args()

    # post-process and validity check
    args.model_type = args.model_type.lower()
    assert not (args.model_type == "tag2text" and args.open_set), \
        "open-set mode only works with RAM"
    if args.backbone is None:
        args.backbone = "swin_l" if args.model_type == "ram" else "swin_b"

    return args

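# Expected on-disk layout, inferred from the path construction in
# `load_dataset` below (shown for RAM; tag2text uses
# *_tag2text_tagidlist.txt and *_tag2text_idannots.txt instead):
#   datasets/<dataset>/imgs/...                   image files
#   datasets/<dataset>/<dataset>_ram_taglist.txt  one tag per line
#   datasets/<dataset>/<dataset>_ram_annots.txt   "<rel_img_path>,<tag>,..." lines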
def load_dataset(
    dataset: str,
    model_type: str,
    input_size: int,
    batch_size: int,
    num_workers: int
) -> Tuple[DataLoader, Dict]:
    dataset_root = str(Path(__file__).resolve().parent / "datasets" / dataset)
    img_root = dataset_root + "/imgs"

    # The tag2text label system contains duplicate tag texts, e.g.
    # "train" (noun) and "train" (verb). Therefore, for tag2text, we use
    # `tagid` instead of `tag`.
    if model_type == "ram":
        tag_file = dataset_root + f"/{dataset}_ram_taglist.txt"
        annot_file = dataset_root + f"/{dataset}_{model_type}_annots.txt"
    else:
        tag_file = dataset_root + f"/{dataset}_tag2text_tagidlist.txt"
        annot_file = dataset_root + f"/{dataset}_{model_type}_idannots.txt"

    with open(tag_file, "r", encoding="utf-8") as f:
        taglist = [line.strip() for line in f]
    with open(annot_file, "r", encoding="utf-8") as f:
        imglist = [img_root + "/" + line.strip().split(",")[0] for line in f]

    class _Dataset(Dataset):
        def __init__(self):
            self.transform = get_transform(input_size)

        def __len__(self):
            return len(imglist)

        def __getitem__(self, index):
            try:
                img = Image.open(imglist[index])
            except (OSError, FileNotFoundError, UnidentifiedImageError):
                # Substitute a tiny black placeholder so the batch keeps
                # its shape; report the bad path for later inspection.
                img = Image.new("RGB", (10, 10), 0)
                print("Error loading image:", imglist[index])
            return self.transform(img)

    loader = DataLoader(
        dataset=_Dataset(),
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        batch_size=batch_size,
        num_workers=num_workers
    )
    info = {
        "taglist": taglist,
        "imglist": imglist,
        "annot_file": annot_file,
        "img_root": img_root
    }
    return loader, info

def get_class_idxs(
    model_type: str,
    open_set: bool,
    taglist: List[str]
) -> Optional[List[int]]:
    """Get indices of the required categories in the model's label system."""
    if model_type == "ram":
        if not open_set:
            model_taglist_file = "ram/data/ram_tag_list.txt"
            with open(model_taglist_file, "r", encoding="utf-8") as f:
                model_taglist = [line.strip() for line in f]
            return [model_taglist.index(tag) for tag in taglist]
        else:
            return None
    else:
        # For tag2text we use tag IDs directly instead of tag text;
        # a tag ID equals the tag's index in the label system.
        return [int(tag) for tag in taglist]

def load_thresholds(
    threshold: Optional[float],
    threshold_file: Optional[str],
    model_type: str,
    open_set: bool,
    class_idxs: List[int],
    num_classes: int,
) -> List[float]:
    """Decide what threshold(s) to use."""
    # Explicit `is None` checks so a user-supplied threshold of 0.0 is
    # not silently treated as "unset".
    if threshold is None and threshold_file is None:  # use default
        if model_type == "ram":
            if not open_set:  # use class-wise tuned thresholds
                ram_threshold_file = "ram/data/ram_tag_list_threshold.txt"
                with open(ram_threshold_file, "r", encoding="utf-8") as f:
                    idx2thre = {
                        idx: float(line.strip()) for idx, line in enumerate(f)
                    }
                return [idx2thre[idx] for idx in class_idxs]
            else:
                return [0.5] * num_classes
        else:
            return [0.68] * num_classes
    elif threshold_file:
        with open(threshold_file, "r", encoding="utf-8") as f:
            thresholds = [float(line.strip()) for line in f]
        assert len(thresholds) == num_classes, \
            "threshold file must contain one threshold per class"
        return thresholds
    else:
        return [threshold] * num_classes

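# Each output line is the image path relative to `img_root`, followed by
# the predicted tags, comma-separated. An illustrative line:
#   subdir/img_001.jpg,person,dog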
def gen_pred_file(
    imglist: List[str],
    tags: List[List[str]],
    img_root: str,
    pred_file: str
) -> None:
    """Generate a text file of tag prediction results."""
    with open(pred_file, "w", encoding="utf-8") as f:
        for image, tag in zip(imglist, tags):
            # Paths should be relative to `img_root` to match the GT file.
            s = str(Path(image).relative_to(img_root))
            if tag:
                s = s + "," + ",".join(tag)
            f.write(s + "\n")

def load_ram(
    backbone: str,
    checkpoint: str,
    input_size: int,
    taglist: List[str],
    open_set: bool,
    class_idxs: List[int],
) -> Module:
    model = ram(pretrained=checkpoint, image_size=input_size, vit=backbone)
    if open_set:
        print("Building tag embeddings ...")
        label_embed, _ = build_openset_label_embedding(taglist)
        model.label_embed = Parameter(label_embed.float())
    else:
        # Trim the label embeddings to the required categories for
        # faster inference.
        model.label_embed = Parameter(model.label_embed[class_idxs, :])
    return model.to(device).eval()

def load_tag2text(
    backbone: str,
    checkpoint: str,
    input_size: int
) -> Module:
    model = tag2text(
        pretrained=checkpoint,
        image_size=input_size,
        vit=backbone
    )
    return model.to(device).eval()

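# RAM scoring: encode the image, project the (possibly trimmed) label
# embeddings, run the tagging head in cross-attention "tagging" mode,
# then map each class embedding to a sigmoid score in [0, 1].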
@torch.no_grad()  # inference only; avoids building autograd graphs
def forward_ram(model: Module, imgs: Tensor) -> Tensor:
    image_embeds = model.image_proj(model.visual_encoder(imgs.to(device)))
    image_atts = torch.ones(
        image_embeds.size()[:-1], dtype=torch.long).to(device)
    label_embed = relu(model.wordvec_proj(model.label_embed)).unsqueeze(0)\
        .repeat(imgs.shape[0], 1, 1)
    tagging_embed, _ = model.tagging_head(
        encoder_embeds=label_embed,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=False,
        mode='tagging',
    )
    return sigmoid(model.fc(tagging_embed).squeeze(-1))

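# Tag2Text scoring: same tagging-head scheme as RAM, except the label
# embeddings come from the model's full embedding table and the score
# matrix is sliced down to the requested class indices afterwards.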
@torch.no_grad()  # inference only; avoids building autograd graphs
def forward_tag2text(
    model: Module,
    class_idxs: List[int],
    imgs: Tensor
) -> Tensor:
    image_embeds = model.visual_encoder(imgs.to(device))
    image_atts = torch.ones(
        image_embeds.size()[:-1], dtype=torch.long).to(device)
    label_embed = model.label_embed.weight.unsqueeze(0)\
        .repeat(imgs.shape[0], 1, 1)
    tagging_embed, _ = model.tagging_head(
        encoder_embeds=label_embed,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=False,
        mode='tagging',
    )
    return sigmoid(model.fc(tagging_embed))[:, class_idxs]

def print_write(f: TextIO, s: str):
    """Print to stdout and mirror the line to an open file."""
    print(s)
    f.write(s + "\n")

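# Pipeline: parse args -> load data -> resolve class indices and
# thresholds -> run (or reload cached) inference -> threshold logits
# into tags -> write prediction and metric files.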
if __name__ == "__main__":
    args = parse_args()

    # set up output paths
    output_dir = args.output_dir
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    pred_file, pr_file, ap_file, summary_file, logit_file = [
        output_dir + "/" + name for name in
        ("pred.txt", "pr.txt", "ap.txt", "summary.txt", "logits.pth")
    ]
    with open(summary_file, "w", encoding="utf-8") as f:
        print_write(f, "****************")
        for key in (
            "model_type", "backbone", "checkpoint", "open_set",
            "dataset", "input_size",
            "threshold", "threshold_file",
            "output_dir", "batch_size", "num_workers"
        ):
            print_write(f, f"{key}: {getattr(args, key)}")
        print_write(f, "****************")

    # prepare data
    loader, info = load_dataset(
        dataset=args.dataset,
        model_type=args.model_type,
        input_size=args.input_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )
    taglist, imglist, annot_file, img_root = \
        info["taglist"], info["imglist"], info["annot_file"], info["img_root"]

    # get class idxs
    class_idxs = get_class_idxs(
        model_type=args.model_type,
        open_set=args.open_set,
        taglist=taglist
    )

    # set up threshold(s)
    thresholds = load_thresholds(
        threshold=args.threshold,
        threshold_file=args.threshold_file,
        model_type=args.model_type,
        open_set=args.open_set,
        class_idxs=class_idxs,
        num_classes=len(taglist)
    )

    # inference: reuse cached logits when available
    if Path(logit_file).is_file():
        logits = torch.load(logit_file)
    else:
        # load model
        if args.model_type == "ram":
            model = load_ram(
                backbone=args.backbone,
                checkpoint=args.checkpoint,
                input_size=args.input_size,
                taglist=taglist,
                open_set=args.open_set,
                class_idxs=class_idxs
            )
        else:
            model = load_tag2text(
                backbone=args.backbone,
                checkpoint=args.checkpoint,
                input_size=args.input_size
            )

        # run inference batch by batch
        logits = torch.empty(len(imglist), len(taglist))
        pos = 0
        for imgs in tqdm(loader, desc="inference"):
            if args.model_type == "ram":
                out = forward_ram(model, imgs)
            else:
                out = forward_tag2text(model, class_idxs, imgs)
            bs = imgs.shape[0]
            logits[pos:pos+bs, :] = out.cpu()
            pos += bs

        # save logits, making later threshold tuning much faster
        torch.save(logits, logit_file)

    # filter with thresholds
    pred_tags = []
    for scores in logits.tolist():
        pred_tags.append([
            taglist[i] for i, s in enumerate(scores) if s >= thresholds[i]
        ])

    # generate result file
    gen_pred_file(imglist, pred_tags, img_root, pred_file)

    # evaluate and record
    mAP, APs = get_mAP(logits.numpy(), annot_file, taglist)
    CP, CR, Ps, Rs = get_PR(pred_file, annot_file, taglist)
    with open(ap_file, "w", encoding="utf-8") as f:
        f.write("Tag,AP\n")
        for tag, AP in zip(taglist, APs):
            f.write(f"{tag},{AP*100.0:.2f}\n")
    with open(pr_file, "w", encoding="utf-8") as f:
        f.write("Tag,Precision,Recall\n")
        for tag, P, R in zip(taglist, Ps, Rs):
            f.write(f"{tag},{P*100.0:.2f},{R*100.0:.2f}\n")
    # append so the configuration summary written above is preserved
    with open(summary_file, "a", encoding="utf-8") as f:
        print_write(f, f"mAP: {mAP*100.0:.2f}")
        print_write(f, f"CP: {CP*100.0:.2f}")
        print_write(f, f"CR: {CR*100.0:.2f}")