from transformers import (
    AutoTokenizer,
    MllamaForConditionalGeneration,
)
import torch
import re
from peft import PeftModel
from datasets import load_from_disk
import pandas as pd
from tqdm import tqdm


model_path = '/gemini/pretrain/meta-llamaLlama-3.2-11B-Vision-Instruct'
lora_path = '/gemini/code/FMD/model/final_model_4/checkpoint-2440'  # checkpoint path of the LoRA output

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Load the base model
model = MllamaForConditionalGeneration.from_pretrained(
    model_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
).eval()

# Load the LoRA weights on top of the base model
model = PeftModel.from_pretrained(model, model_id=lora_path)
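
# Optional (assumption, not part of the original script): PEFT's merge_and_unload() folds the
# LoRA weights into the base model, which can speed up generation during inference.
# model = model.merge_and_unload()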
test_dataset = load_from_disk("/gemini/code/FMD/final_dataset/Test")
results = []
with torch.no_grad():
    for data in tqdm(test_dataset):
        # Pre-truncate the instruction text so the full prompt stays within the context limit
        model_input = tokenizer(
            data['instruction_1'],        # input text
            add_special_tokens=False,     # do not add special tokens
            truncation=True,              # enable truncation
            max_length=3000               # maximum length of the instruction part
        )
        model_input = tokenizer.decode(model_input["input_ids"], skip_special_tokens=False)

        # Build the chat-formatted prompt and tokenize it
        model_inputs = tokenizer(
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are an expert in financial misinformation detection.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{model_input}\nimage information: {data['image_info']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
            truncation=True,
            max_length=3600,
            add_special_tokens=False,
            return_tensors="pt"
        ).to('cuda')

        # Generate the model output
        generated_ids = model.generate(**model_inputs, max_new_tokens=1024)

        # Strip the input tokens so only the generated prediction remains
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        # Decode the generated prediction
        responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        print(responses)
        # Store each result in order
        results.append({
            "ID": data['ID'],
            "response": responses
        })

def split_response(text):
    """Extract the Prediction label and Explanation text from a raw model response."""
    # Extract the content of Prediction
    prediction_pattern = r"Prediction:\s*(False|True|NEI)\s*$"
    prediction_match = re.search(prediction_pattern, text, re.MULTILINE)
    if prediction_match:
        prediction = prediction_match.group(1).strip()
    else:
        prediction = 'None'
        print("No matching prediction found")
    # Extract the content of Explanation
    explanation_pattern = r"Explanation:\s*(.*)"
    explanation_match = re.search(explanation_pattern, text, re.MULTILINE)
    if explanation_match:
        explanation = explanation_match.group(1).strip()
    else:
        explanation = None  # set to None if there is no match
    return prediction, explanation
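
# Illustrative example (not in the original script) of how split_response parses a response:
# >>> split_response("Explanation: the claim contradicts the filing.\nPrediction: False")
# ('False', 'the claim contradicts the filing.')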

if results:
    df = pd.DataFrame(results)

    # Parse the prediction label and explanation out of each raw response
    for index, row in df.iterrows():
        text = row['response']
        prediction, explanation = split_response(text)
        df.at[index, 'Prediction'] = prediction
        df.at[index, 'Explanation'] = explanation

    # Normalize the ID values and column names to the expected submission format
    df['ID'] = df['ID'].str.replace('FMD_test_', '', regex=False)
    df = df.rename(columns={'ID': 'id', 'Prediction': 'pred', 'Explanation': 'explanation'})
    df = df.drop('response', axis=1)

    # Map the textual labels to integer class ids
    mapping = {
        'False': 0,
        'True': 1,
        'NEI': 2
    }
    df['pred'] = df['pred'].replace(mapping)
    df.to_csv("/gemini/code/FMD/inference/result_final_model_4/result.csv", index=False)
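
# Note (assumption, not part of the original script): DataFrame.to_csv raises an error if the
# target directory does not exist, so it may help to create it first, e.g.
# import os
# os.makedirs("/gemini/code/FMD/inference/result_final_model_4", exist_ok=True)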