import os
import csv
import json
import torch
import argparse
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from torch.utils.data import DataLoader
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
from peft import LoraConfig, get_peft_model
from data_utils.xgpt3_dataset import MultiModalDataset
from utils import batchify
import gradio as gr
from entailment_inference import get_scores
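# NOTE: the checkpoint paths below are assumed placeholders (they are not defined
# anywhere in this snippet). Point them at the mPLUG-Owl video base model and at
# your downloaded fine-tuned checkpoint before running.
pretrained_ckpt = 'MAGAer13/mplug-owl-llama-7b-video'   # assumed Hugging Face base checkpoint
trained_ckpt = 'checkpoints/owl-con-checkpoint.pth'     # placeholder path to the fine-tuned weights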
print(f"Is CUDA available: {torch.cuda.is_available()}") | |
# True | |
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") | |
# Tesla T4 | |
tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)
# Instantiate model
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    torch_dtype=torch.bfloat16,
    device_map={'': 0}
)
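# Freeze the base model's parameters before attaching LoRA adapters
# (inference only, so no gradients are needed).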
for param in model.parameters():
    param.requires_grad = False
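# LoRA adapters on the language model's attention and feed-forward projections,
# presumably matching the configuration used during fine-tuning;
# inference_mode=True keeps the adapters frozen as well.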
peft_config = LoraConfig(
    target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
    inference_mode=True,
    r=32,
    lora_alpha=16,
    lora_dropout=0.05
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
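# Load the fine-tuned checkpoint onto GPU 0 and cast the whole model to bfloat16.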
with open(trained_ckpt, 'rb') as f:
    ckpt = torch.load(f, map_location=torch.device('cuda:0'))
model.load_state_dict(ckpt)
model = model.to(torch.bfloat16)
print('Model Loaded')
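# The <|video|> token marks where the processor injects the video features; the
# question then asks whether the video entails the given description.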
PROMPT = """The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. | |
Human: <|video|> | |
Human: Does this video entail the description: ""A basketball team walking off the field while the audience claps.""? | |
AI: """ | |
valid_data = MultiModalDataset("examples/y5xuvHpDPZQ_000005_000015.mp4", PROMPT, tokenizer, processor, max_length=256, loss_objective='sequential')
dataloader = DataLoader(valid_data, pin_memory=True, collate_fn=batchify)
score = get_scores(model, tokenizer, dataloader)
print(score)