import os
import csv
import json
import torch
import argparse
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from torch.utils.data import DataLoader
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
from peft import LoraConfig, get_peft_model
from data_utils.xgpt3_dataset import MultiModalDataset
from utils import batchify

parser = argparse.ArgumentParser()
parser.add_argument('--input_csv', type=str, required=True, help='input csv file')
parser.add_argument('--output_csv', type=str, required=True, help='output csv with scores')
parser.add_argument('--pretrained_ckpt', type=str, required=True, help='pretrained ckpt')
parser.add_argument('--trained_ckpt', type=str, help='trained ckpt (required with --use_lora)')
parser.add_argument('--lora_r', type=int, default=32)
parser.add_argument('--use_lora', action='store_true', help='lora model')
parser.add_argument('--all-params', action='store_true', help='use all params of the model')  # argparse exposes this as args.all_params
parser.add_argument('--batch_size', type=int, default=32)
args = parser.parse_args()
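
# Illustrative invocation (the script name 'entailment_inference.py' and all
# paths are assumptions, not taken from this file):
#   python entailment_inference.py --input_csv data/captions.csv \
#       --output_csv scores.csv --pretrained_ckpt <path-to-mplug-owl-ckpt> \
#       --trained_ckpt <path-to-lora-ckpt> --use_lora --batch_size 32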

softmax = nn.Softmax(dim=2)  # normalize over the vocabulary dimension

def get_entail(logits, input_ids, tokenizer):
    """Compute a Yes/No entailment score for each sequence in the batch."""
    logits = softmax(logits)
    token_id_yes = tokenizer.encode('Yes', add_special_tokens=False)[0]
    token_id_no = tokenizer.encode('No', add_special_tokens=False)[0]
    entailment = []
    for j in range(len(logits)):
        # Locate the last non-padded position (or the final position when the
        # sequence has no padding); the Yes/No logits are read there.
        for i in range(len(input_ids[j])):
            if input_ids[j][i] == tokenizer.pad_token_id:  # pad token if the answer is not present
                i = i - 1
                break
            elif i == len(input_ids[j]) - 1:
                break
        # Renormalize P('Yes') against P('No'), ignoring the rest of the vocabulary.
        score = logits[j][i][token_id_yes] / (logits[j][i][token_id_yes] + logits[j][i][token_id_no])
        entailment.append(score)
    entailment = torch.stack(entailment)
    return entailment
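
# Worked example with made-up numbers: if the softmaxed logits at the chosen
# position give P('Yes') = 0.60 and P('No') = 0.20, the entailment score is
# 0.60 / (0.60 + 0.20) = 0.75, i.e. the probability of 'Yes' renormalized
# against 'No' alone.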

def get_scores(model, tokenizer, dataloader):
    with torch.no_grad():
        for index, inputs in enumerate(tqdm(dataloader)):
            # Cast float tensors to bfloat16 and move everything onto the model's device.
            for k, v in inputs.items():
                if torch.is_tensor(v):
                    if v.dtype == torch.float:
                        inputs[k] = v.bfloat16()
                    inputs[k] = inputs[k].to(model.device)
            outputs = model(pixel_values=inputs['pixel_values'], video_pixel_values=inputs['video_pixel_values'],
                            labels=None, num_images=inputs['num_images'], num_videos=inputs['num_videos'],
                            input_ids=inputs['input_ids'], non_padding_mask=inputs['non_padding_mask'],
                            non_media_mask=inputs['non_media_mask'], prompt_mask=inputs['prompt_mask'])
            logits = outputs['logits']
            entail_scores = get_entail(logits, inputs['input_ids'], tokenizer)
            # Append one row per example; the file is opened once per batch rather than once per row.
            with open(args.output_csv, 'a') as f:
                writer = csv.writer(f)
                for m in range(len(entail_scores)):
                    writer.writerow([inputs['videopaths'][m], inputs['captions'][m], entail_scores[m].item()])
            print(f"Batch {index} Done")

def main():
    pretrained_ckpt = args.pretrained_ckpt

    # Processors
    tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
    image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
    processor = MplugOwlProcessor(image_processor, tokenizer)

    valid_data = MultiModalDataset(args.input_csv, tokenizer, processor, max_length=256, loss_objective='sequential')
    dataloader = DataLoader(valid_data, batch_size=args.batch_size, pin_memory=True, collate_fn=batchify)

    # Instantiate model
    model = MplugOwlForConditionalGeneration.from_pretrained(
        pretrained_ckpt,
        torch_dtype=torch.bfloat16,
        device_map={'': 0}
    )

    if args.use_lora:
        # Freeze the base model and wrap it with LoRA adapters before loading
        # the finetuned weights.
        for name, param in model.named_parameters():
            param.requires_grad = False
        if args.all_params:
            # Apply LoRA to the attention and MLP projections of the language model.
            peft_config = LoraConfig(
                target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
                inference_mode=True,
                r=args.lora_r,
                lora_alpha=16,
                lora_dropout=0.05
            )
        else:
            # Apply LoRA to the attention projections only.
            peft_config = LoraConfig(
                target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj)',
                inference_mode=True,
                r=args.lora_r,
                lora_alpha=16,
                lora_dropout=0.05
            )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

        with open(args.trained_ckpt, 'rb') as f:
            ckpt = torch.load(f, map_location=torch.device('cuda:0'))
        model.load_state_dict(ckpt)
        model = model.to(torch.bfloat16)
        print('Model Loaded')

    model.eval()
    get_scores(model, tokenizer, dataloader)

if __name__ == "__main__":
    main()