import gradio as gr
import torch
from torch import nn
from huggingface_hub import hf_hub_download
from transformers import BertModel, BertTokenizer, CLIPModel, BertConfig, CLIPConfig, CLIPProcessor

from modeling_unimo import UnimoForMaskedLM


def load_dict_text(path):
    """Load a tab-separated "key<TAB>value" file into a dict."""
    with open(path, 'r') as f:
        load_data = {}
        lines = f.readlines()
        for line in lines:
            key, value = line.split('\t')
            load_data[key] = value.replace('\n', '')
    return load_data


def load_text(path):
    """Load a file with one entry per line into a list."""
    with open(path, 'r') as f:
        lines = f.readlines()
        load_data = []
        for line in lines:
            load_data.append(line.strip().replace('\n', ''))
    return load_data


class MKGformerModel(nn.Module):
    def __init__(self, vision_config, text_config):
        super().__init__()
        self.model = UnimoForMaskedLM(vision_config, text_config)

    def forward(self, batch):
        return self.model(**batch, return_dict=True)


# tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# entity and relation
ent2text = load_dict_text('./dataset/MarKG/entity2text.txt')
rel2text = load_dict_text('./dataset/MarKG/relation2text.txt')
analogy_entities = load_text('./dataset/MARS/analogy_entities.txt')
analogy_relations = load_text('./dataset/MARS/analogy_relations.txt')
ent2description = load_dict_text('./dataset/MarKG/entity2textlong.txt')

text2ent = {text: ent for ent, text in ent2text.items()}

ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(ent2description)}
rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(rel2text)}
analogy_ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(ent2description) if ent in analogy_entities}
analogy_rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(rel2text) if rel in analogy_relations}

entity_list = list(ent2token.values())
relation_list = list(rel2token.values())
analogy_ent_list = list(analogy_ent2token.values())
analogy_rel_list = list(analogy_rel2token.values())

num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': entity_list})
num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': relation_list})
vocab = tokenizer.get_added_vocab()    # dict: word -> idx
relation_id_st = vocab[relation_list[0]]
relation_id_ed = vocab[relation_list[-1]] + 1
entity_id_st = vocab[entity_list[0]]
entity_id_ed = vocab[entity_list[-1]] + 1

# analogy entities and relations
analogy_entity_ids = [vocab[ent] for ent in analogy_ent_list]
analogy_relation_ids = [vocab[rel] for rel in analogy_rel_list]

num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': ["[R]"]})

# model
checkpoint_path = hf_hub_download(repo_id='flow3rdown/mkgformer_mart_ft', filename="mkgformer_mart_ft", repo_type='model')
clip_config = CLIPConfig.from_pretrained('openai/clip-vit-base-patch32').vision_config
clip_config.device = 'cpu'
bert_config = BertConfig.from_pretrained('bert-base-uncased')

mkgformer = MKGformerModel(clip_config, bert_config)
mkgformer.model.resize_token_embeddings(len(tokenizer))
mkgformer.load_state_dict(torch.load(checkpoint_path, map_location='cpu')["state_dict"])

# processor
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
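
# Inference helpers: each *_inference_* function below encodes one analogy query for the
# masked-LM head. The example pair and the question are packed into the two BERT segments
#   "[ENTITY_h] ... [SEP] [R] [SEP] [ENTITY_t] ..."   and   "[ENTITY_q] ... [SEP] [R] [SEP] [MASK]",
# the attention mask is expanded to a square matrix so the example-pair tokens do not attend
# to the question tokens, and the answer is the analogy entity whose special token scores
# highest at the [MASK] position.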
def single_inference_iit(head_img, head_id, tail_img, tail_id, question_txt, question_id):    # (I, I) -> (T, ?)
    ques_ent_text = ent2description[question_id]
    inputs = tokenizer(
        tokenizer.sep_token.join([analogy_ent2token[head_id] + " ", "[R] ", analogy_ent2token[tail_id] + " "]),
        tokenizer.sep_token.join([analogy_ent2token[question_id] + " " + ques_ent_text, "[R] ", "[MASK]"]),
        truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
    sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
    inputs['sep_idx'] = torch.tensor(sep_idx)
    inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand(
        [inputs['input_ids'].size(0), inputs['input_ids'].size(1), inputs['input_ids'].size(1)]).clone()
    for i, idx in enumerate(sep_idx):
        inputs['attention_mask'][i, :idx[2], idx[2]:] = 0

    # image
    pixel_values = processor(images=[head_img, tail_img], return_tensors='pt')['pixel_values'].squeeze()
    inputs['pixel_values'] = pixel_values.unsqueeze(0)

    input_ids = inputs['input_ids']
    model_output = mkgformer.model(**inputs, return_dict=True)
    logits = model_output[0].logits
    bsz = input_ids.shape[0]

    _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)    # bsz
    mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]       # bsz, 1, entity
    answer = ent2text[list(analogy_ent2token.keys())[mask_logits.argmax().item()]]

    return answer


def single_inference_tti(head_txt, head_id, tail_txt, tail_id, question_img, question_id):    # (T, T) -> (I, ?)
    head_ent_text, tail_ent_text = ent2description[head_id], ent2description[tail_id]
    inputs = tokenizer(
        tokenizer.sep_token.join([analogy_ent2token[head_id] + " " + head_ent_text, "[R] ", analogy_ent2token[tail_id] + " " + tail_ent_text]),
        tokenizer.sep_token.join([analogy_ent2token[question_id] + " ", "[R] ", "[MASK]"]),
        truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
    sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
    inputs['sep_idx'] = torch.tensor(sep_idx)
    inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand(
        [inputs['input_ids'].size(0), inputs['input_ids'].size(1), inputs['input_ids'].size(1)]).clone()
    for i, idx in enumerate(sep_idx):
        inputs['attention_mask'][i, :idx[2], idx[2]:] = 0

    # image
    pixel_values = processor(images=question_img, return_tensors='pt')['pixel_values'].unsqueeze(1)
    pixel_values = torch.cat((pixel_values, torch.zeros_like(pixel_values)), dim=1)
    inputs['pixel_values'] = pixel_values

    input_ids = inputs['input_ids']
    model_output = mkgformer.model(**inputs, return_dict=True)
    logits = model_output[0].logits
    bsz = input_ids.shape[0]

    _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)    # bsz
    mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]       # bsz, 1, entity
    answer = ent2text[list(analogy_ent2token.keys())[mask_logits.argmax().item()]]

    return answer
def blended_inference_iti(head_img, head_id, tail_txt, tail_id, question_img, question_id):    # (I, T) -> (I, ?)
    head_ent_text, tail_ent_text = ent2description[head_id], ent2description[tail_id]
    inputs = tokenizer(
        tokenizer.sep_token.join([analogy_ent2token[head_id], "[R] ", analogy_ent2token[tail_id] + " " + tail_ent_text]),
        tokenizer.sep_token.join([analogy_ent2token[question_id] + " ", "[R] ", "[MASK]"]),
        truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
    sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
    inputs['sep_idx'] = torch.tensor(sep_idx)
    inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand(
        [inputs['input_ids'].size(0), inputs['input_ids'].size(1), inputs['input_ids'].size(1)]).clone()
    for i, idx in enumerate(sep_idx):
        inputs['attention_mask'][i, :idx[2], idx[2]:] = 0

    # image
    pixel_values = processor(images=[head_img, question_img], return_tensors='pt')['pixel_values'].squeeze()
    inputs['pixel_values'] = pixel_values.unsqueeze(0)

    input_ids = inputs['input_ids']
    model_output = mkgformer.model(**inputs, return_dict=True)
    logits = model_output[0].logits
    bsz = input_ids.shape[0]

    _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)    # bsz
    mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]       # bsz, 1, entity
    answer = ent2text[list(analogy_ent2token.keys())[mask_logits.argmax().item()]]

    return answer
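
# UI builders: each *_tab_* function below lays out one analogy setting (image or text
# slots for the example pair and the question), wires the Submit button to the matching
# inference function above, and registers one example row (not cached, not auto-run).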
def single_tab_iit():
    with gr.Column():
        gr.Markdown(""" $(I_h, I_t) : (T_q, ?)$ """)
        with gr.Row():
            with gr.Column():
                head_image = gr.Image(type='pil', label="Head Image")
                head_ent = gr.Textbox(lines=1, label="Head Entity")
            with gr.Column():
                tail_image = gr.Image(type='pil', label="Tail Image")
                tail_ent = gr.Textbox(lines=1, label="Tail Entity")
            with gr.Column():
                question_text = gr.Textbox(lines=1, label="Question Name")
                question_ent = gr.Textbox(lines=1, label="Question Entity")

        submit_btn = gr.Button("Submit")
        output_text = gr.Textbox(label="Output")

        submit_btn.click(fn=single_inference_iit,
                         inputs=[head_image, head_ent, tail_image, tail_ent, question_text, question_ent],
                         outputs=[output_text])

        examples = [['examples/tree.jpg', 'Q10884', 'examples/forest.jpg', 'Q4421', "Anhui", 'Q40956']]
        ex = gr.Examples(
            examples=examples,
            fn=single_inference_iit,
            inputs=[head_image, head_ent, tail_image, tail_ent, question_text, question_ent],
            outputs=[output_text],
            cache_examples=False,
            run_on_click=False
        )


def single_tab_tti():
    with gr.Column():
        gr.Markdown(""" $(T_h, T_t) : (I_q, ?)$ """)
        with gr.Row():
            with gr.Column():
                head_text = gr.Textbox(lines=1, label="Head Name")
                head_ent = gr.Textbox(lines=1, label="Head Entity")
            with gr.Column():
                tail_text = gr.Textbox(lines=1, label="Tail Name")
                tail_ent = gr.Textbox(lines=1, label="Tail Entity")
            with gr.Column():
                question_image = gr.Image(type='pil', label="Question Image")
                question_ent = gr.Textbox(lines=1, label="Question Entity")

        submit_btn = gr.Button("Submit")
        output_text = gr.Textbox(label="Output")

        submit_btn.click(fn=single_inference_tti,
                         inputs=[head_text, head_ent, tail_text, tail_ent, question_image, question_ent],
                         outputs=[output_text])

        examples = [['scrap', 'Q3217573', 'watch', 'Q178794', 'examples/root.jpg', 'Q111029']]
        ex = gr.Examples(
            examples=examples,
            fn=single_inference_tti,
            inputs=[head_text, head_ent, tail_text, tail_ent, question_image, question_ent],
            outputs=[output_text],
            cache_examples=False,
            run_on_click=False
        )


def blended_tab_iti():
    with gr.Column():
        gr.Markdown(""" $(I_h, T_t) : (I_q, ?)$ """)
        with gr.Row():
            with gr.Column():
                head_image = gr.Image(type='pil', label="Head Image")
                head_ent = gr.Textbox(lines=1, label="Head Entity")
            with gr.Column():
                tail_txt = gr.Textbox(lines=1, label="Tail Name")
                tail_ent = gr.Textbox(lines=1, label="Tail Entity")
            with gr.Column():
                question_image = gr.Image(type='pil', label="Question Image")
                question_ent = gr.Textbox(lines=1, label="Question Entity")

        submit_btn = gr.Button("Submit")
        output_text = gr.Textbox(label="Output")

        submit_btn.click(fn=blended_inference_iti,
                         inputs=[head_image, head_ent, tail_txt, tail_ent, question_image, question_ent],
                         outputs=[output_text])

        examples = [['examples/watermelon.jpg', 'Q38645', 'fruit', 'Q3314483', 'examples/coffee.jpeg', 'Q8486']]
        ex = gr.Examples(
            examples=examples,
            fn=blended_inference_iti,
            inputs=[head_image, head_ent, tail_txt, tail_ent, question_image, question_ent],
            outputs=[output_text],
            cache_examples=False,
            run_on_click=False
        )


TITLE = """MKG Analogy"""

with gr.Blocks() as block:
    with gr.Column(elem_id="col-container"):
        gr.HTML(TITLE)

        with gr.Tab("Single Analogical Reasoning"):
            single_tab_iit()
            single_tab_tti()
        with gr.Tab("Blended Analogical Reasoning"):
            blended_tab_iti()

        # gr.HTML(ARTICLE)

block.queue(max_size=64).launch(enable_queue=True)
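
# A minimal programmatic sketch, kept commented out so the script still just launches the UI.
# It reuses the first example row from single_tab_iit; the paths and Wikidata IDs are the ones
# bundled with this demo, and it assumes the dataset/ and examples/ folders are present locally.
#
# from PIL import Image
# head_img = Image.open('examples/tree.jpg')
# tail_img = Image.open('examples/forest.jpg')
# print(single_inference_iit(head_img, 'Q10884', tail_img, 'Q4421', 'Anhui', 'Q40956'))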