import re

import torch
import pandas as pd
from tqdm import tqdm
from datasets import load_from_disk
from peft import PeftModel
from transformers import (
    AutoTokenizer,
    MllamaForConditionalGeneration,
)
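
# Paths to the base model weights and the trained LoRA adapter checkpoint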
model_path = '/gemini/pretrain/meta-llamaLlama-3.2-11B-Vision-Instruct'
lora_path = '/gemini/code/FMD/model/final_model_4/checkpoint-2440'
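
# Load the tokenizer that ships with the base model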
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
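
# Load the base vision-language model in bfloat16, sharded across available GPUs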
model = MllamaForConditionalGeneration.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).eval()
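
# Attach the fine-tuned LoRA adapter on top of the base model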
model = PeftModel.from_pretrained(model, model_id=lora_path)

test_dataset = load_from_disk("/gemini/code/FMD/final_dataset/Test")
results = []
with torch.no_grad():
    for data in tqdm(test_dataset):
        # Truncate the instruction to at most 3000 tokens, then decode the ids
        # back to text so the truncated instruction can be spliced into the
        # chat prompt below.
        model_input = tokenizer(
            data['instruction_1'],
            add_special_tokens=False,
            truncation=True,
            max_length=3000
        )
        model_input = tokenizer.decode(model_input["input_ids"], skip_special_tokens=False)
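
        # Build the Llama-3 chat prompt by hand and move the tensors to the GPU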
        model_inputs = tokenizer(
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
            f"You are an expert in financial misinformation detection.<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n\n{model_input}\n"
            f"image information: {data['image_info']}<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n",
            truncation=True,
            max_length=3600,
            add_special_tokens=False,
            return_tensors="pt"
        ).to('cuda')
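
        # Generate up to 1024 new tokens for this example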
        generated_ids = model.generate(**model_inputs, max_new_tokens=1024)
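
        # Strip the prompt so only the newly generated tokens remain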
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
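
        # Decode the new tokens (batch of one) and record the response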
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        print(response)

        results.append({
            "ID": data['ID'],
            "response": response
        })


def split_response(text):
    # Pull the final "Prediction: ..." label (False / True / NEI); the $ anchor
    # with re.MULTILINE matches the label at the end of a line.
    prediction_pattern = r"Prediction:\s*(False|True|NEI)\s*$"
    prediction_match = re.search(prediction_pattern, text, re.MULTILINE)
    if prediction_match:
        prediction = prediction_match.group(1).strip()
    else:
        prediction = 'None'
        print("No matching content found")

    # Pull the free-text explanation that follows "Explanation:"
    explanation_pattern = r"Explanation:\s*(.*)"
    explanation_match = re.search(explanation_pattern, text, re.MULTILINE)
    if explanation_match:
        explanation = explanation_match.group(1).strip()
    else:
        explanation = None
    return prediction, explanation
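

# Assemble the parsed results into the submission CSV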
if results:
    df = pd.DataFrame(results)

    # Parse each raw response into a prediction label and an explanation
    for index, row in df.iterrows():
        text = row['response']
        prediction, explanation = split_response(text)
        df.at[index, 'Prediction'] = prediction
        df.at[index, 'Explanation'] = explanation

    # Normalize IDs and column names to the expected submission format
    df['ID'] = df['ID'].str.replace('FMD_test_', '', regex=False)
    df = df.rename(columns={'ID': 'id', 'Prediction': 'pred', 'Explanation': 'explanation'})
    df = df.drop('response', axis=1)

    # Map string labels to integer class ids
    mapping = {
        'False': 0,
        'True': 1,
        'NEI': 2
    }
    df['pred'] = df['pred'].replace(mapping)
    df.to_csv("/gemini/code/FMD/inference/result_final_model_4/result.csv", index=False)