# Source: MoMA_demo/model_lib/utils.py
# Last commit: "recreate" (8a4a948) by Kunpeng Song
# File-viewer residue from the hosting page: raw / history / blame /
# contribute / delete; "No virus"; 1.12 kB
import argparse
import torch
from torchvision.transforms import ToPILImage
from PIL import Image
def parse_args():
    """Assemble the runtime configuration for the MoMA example.

    Three string options are read from the command line; any other
    arguments present on the command line are ignored rather than rejected.
    The returned namespace is then extended with settings that are not
    exposed as flags: ``device`` (CUDA device 0), ``load_8bit`` (False) and
    ``load_4bit`` (True).

    Returns:
        argparse.Namespace: parsed options plus the fixed attributes above.
    """
    # (flag, default, help) for every string-valued option.
    string_options = (
        (
            "--load_attn_adapters",
            "checkpoints/attn_adapters_projectors.th",
            "self_cross attentions and LLM projectors.",
        ),
        (
            "--output_path",
            "output",
            "output directory.",
        ),
        (
            "--model_path",
            "KunpengSong/MoMA_llava_7b",
            "fine tuned llava (Multi-modal LLM decoder)",
        ),
    )

    cli = argparse.ArgumentParser(description="Simple example of MoMA.")
    for flag, default_value, help_text in string_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    # parse_known_args returns (namespace, leftovers); leftovers are dropped
    # so that flags this parser does not declare do not abort the run.
    namespace, _unrecognized = cli.parse_known_args()

    # Settings that are fixed in code rather than taken from the command line.
    namespace.device = torch.device("cuda", 0)
    namespace.load_8bit = False
    namespace.load_4bit = True
    return namespace
def show_PIL_image(tensor):
    """Lay the images of a batch side by side in a single PIL image.

    Each ``tensor[i]`` is converted with torchvision's ``ToPILImage`` and
    pasted left to right onto one RGB canvas. The canvas is sized from the
    converted images themselves, so any batch size works (the original
    comment cites an input of shape ``[3, 3, 512, 512]``).

    Args:
        tensor: batch of images indexed along its first dimension; must
            expose ``.shape`` and each ``tensor[i]`` must be accepted by
            ``ToPILImage``.

    Returns:
        PIL.Image.Image: RGB image containing every batch entry in order.

    Raises:
        ValueError: if the batch holds no images.
    """
    to_pil = ToPILImage()
    images = [to_pil(tensor[i]) for i in range(tensor.shape[0])]
    if not images:
        raise ValueError("show_PIL_image needs at least one image in the batch")

    # Size the canvas from the actual images instead of assuming exactly
    # three of them; for three equally sized images this matches the old
    # ``width * 3`` layout exactly.
    canvas_width = sum(img.width for img in images)
    canvas_height = max(img.height for img in images)
    concatenated_image = Image.new('RGB', (canvas_width, canvas_height))

    x_offset = 0
    for img in images:
        concatenated_image.paste(img, (x_offset, 0))
        x_offset += img.width
    return concatenated_image