damerajee committed · Commit 9e1161f · verified · 1 Parent(s): 8c80497

Update modeling_gpt2vision.py

Files changed (1): modeling_gpt2vision.py +7 -78
modeling_gpt2vision.py CHANGED
@@ -1,17 +1,16 @@
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, AutoTokenizer
-from .configuration_gpt2vision import GPT2VisionConfig, GPT2Config
+from .configuration_gpt2vision import GPT2VisionConfig
 from .vision_encoder import VisionEncoder
 from .modeling_gpt2 import GPT2LMHeadModel
 
-
 IMAGE_TOKEN = "<image>"
 ANSWER_EOS = "<|endoftext|>"
 
 def resize_token_embeds(model_name="openai-community/gpt2"):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    new_tokens={
+    new_tokens = {
         "additional_special_tokens": [IMAGE_TOKEN]
     }
     tokenizer.add_special_tokens(new_tokens)
@@ -19,30 +18,6 @@ def resize_token_embeds(model_name="openai-community/gpt2"):
 
 tokenizer = resize_token_embeds()
 
-def create_labels(input_ids, tokenizer, attention_mask):
-    labels = input_ids.clone()
-
-    labels[attention_mask == 0] = -100
-
-    answer_start_tokens = tokenizer.encode("Answer:", add_special_tokens=False)
-
-    for i, seq in enumerate(input_ids):
-        # Find the start of the answer
-        answer_start = (seq == answer_start_tokens[0]).nonzero(as_tuple=True)[0]
-        if len(answer_start) > 0:
-            answer_start = answer_start[0]
-            if seq[answer_start:answer_start+len(answer_start_tokens)].tolist() == answer_start_tokens:
-                # Mask out everything before the answer
-                labels[i, :answer_start] = -100
-
-        # Find the end of the sequence (last non-padding token)
-        sequence_end = attention_mask[i].nonzero(as_tuple=True)[0][-1]
-
-        # Keep the last token (EOS) as part of the label
-        labels[i, sequence_end+1:] = -100
-
-    return labels
-
 class MLP(nn.Module):
     def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None):
         super().__init__()
@@ -53,12 +28,6 @@ class MLP(nn.Module):
         self.fc2 = nn.Linear(hidden_features, out_features)
         self.dropout = nn.Dropout(p=0.1)
 
-        # Initialize weights
-        nn.init.xavier_normal_(self.fc1.weight)
-        nn.init.zeros_(self.fc1.bias)
-        nn.init.xavier_normal_(self.fc2.weight)
-        nn.init.zeros_(self.fc2.bias)
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc1(x)
         x = self.act(x)
@@ -66,7 +35,6 @@ class MLP(nn.Module):
         x = self.fc2(x)
         return x
 
-
 class GPT2Vision(PreTrainedModel):
     config_class = GPT2VisionConfig
 
@@ -74,35 +42,21 @@ class GPT2Vision(PreTrainedModel):
         super().__init__(config)
         self.vision_encoder = VisionEncoder()
         self.mlp = MLP(in_features=768, hidden_features=768 * 4, out_features=768)
-
         self.language_model = GPT2LMHeadModel(config.gpt2_config)
-
         self.language_model.resize_token_embeddings(len(tokenizer))
-
         self.tokenizer = tokenizer
         tokenizer.pad_token = tokenizer.eos_token
-
         self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
 
     @property
     def device(self):
         return next(self.language_model.parameters()).device
 
-    def freeze_model_components(self, freeze_vision=True, freeze_language=True, freeze_mlp=True):
-        for param in self.vision_encoder.parameters():
-            param.requires_grad = not freeze_vision
-        for param in self.language_model.parameters():
-            param.requires_grad = not freeze_language
-        for param in self.mlp.parameters():
-            param.requires_grad = not freeze_mlp
-
     def tokenize_encode(self, batch, device):
         text = batch['text']
         images = batch['image']
-
         if isinstance(text, str):
             text = [text]
-
         input_texts = [f"{IMAGE_TOKEN}{t}" for t in text]
         text_inputs = self.tokenizer(
             input_texts,
@@ -112,55 +66,32 @@ class GPT2Vision(PreTrainedModel):
             return_tensors="pt",
             pad_to_multiple_of=8,
         ).to(device)
-
-        pixel_values = self.vision_encoder(images,device)
-
+        pixel_values = self.vision_encoder(images, device)
         return {
             "input_ids": text_inputs.input_ids,
             "attention_mask": text_inputs.attention_mask,
             "pixel_values": pixel_values
         }
-
+
     def preprocess_inputs(self, batch):
         pixel_values = batch['pixel_values'].squeeze(1)
         input_ids = batch['input_ids'].squeeze(1)
         attention_mask = batch['attention_mask'].squeeze(1)
-
         input_ids = input_ids.to(self.device)
         attention_mask = attention_mask.to(self.device)
         pixel_values = pixel_values.to(self.device)
-
-        labels = create_labels(input_ids, self.tokenizer, attention_mask)
-        labels = labels.to(self.device)
-
         img_embs = self.mlp(pixel_values)
         tok_embs = self.language_model.get_input_embeddings()(input_ids)
-
         inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
-
         img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
         attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
+        return inputs_embeds, attention_mask, input_ids
 
-        img_labels = torch.full((labels.size(0), img_embs.size(1)), fill_value=-100, dtype=torch.long, device=self.device)
-        labels = torch.cat((labels[:, 0:1], img_labels, labels[:, 1:]), dim=1)
-        return inputs_embeds, attention_mask, input_ids, labels
-
-    def forward(self, batch, **kwargs):
-        inputs_embeds, attention_mask, input_ids, labels = self.preprocess_inputs(batch)
-        outputs = self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
-
-
-        return outputs
-
     def generate(self, question, image, max_new_tokens=30, **kwargs):
-        prompt = prompt = f"Question: {question}\nAnswer:"
+        prompt = f"Question: {question}\nAnswer:"
         batch = {"image": [image], "text": prompt}
         encoded_batch = self.tokenize_encode(batch, self.device)
-        inputs_embeds, attention_mask, input_ids, _ = self.preprocess_inputs(encoded_batch)
-
-
-
-
+        inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(encoded_batch)
         output_sequences = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -168,9 +99,7 @@ class GPT2Vision(PreTrainedModel):
             eos_token_id=self.tokenizer.eos_token_id,
             max_new_tokens=max_new_tokens,
             repetition_penalty=1.0,
-
             **kwargs
         )
-
         output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
         return output
 
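For reference, a minimal, hypothetical inference sketch against the API this file exposes after the change (only generate remains for callers; the training-side forward and label-masking helpers are gone). The checkpoint id, the image path, and loading through AutoModelForCausalLM with trust_remote_code=True are illustrative assumptions, not something this commit specifies.

# Hypothetical usage sketch, not part of the commit. Assumes the repo is published
# on the Hub with this custom code wired up for auto-loading; the checkpoint id
# and image path below are placeholders.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "damerajee/GPT2Vision",  # placeholder checkpoint id
    trust_remote_code=True,  # assumption: GPT2Vision is registered via auto_map
)
model.eval()

image = Image.open("example.jpg")  # placeholder image path

# GPT2Vision.generate builds the "Question: ...\nAnswer:" prompt internally,
# prepends the projected image embeddings, and decodes the model's answer.
with torch.no_grad():
    answer = model.generate("What is in the picture?", image, max_new_tokens=30)
print(answer)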