import copy
import datetime
import math
import os
import random
import re
import time

import gradio as gr
import nltk
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, random_split
from tqdm import trange
from transformers import AdamW, GPT2Tokenizer, GPT2TokenizerFast, get_linear_schedule_with_warmup

from load_weights import load_weight

torch.manual_seed(42)
nltk.download('punkt')
# Makes CUDA errors synchronous so tracebacks point at the failing op.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def gelu(x):
    """Tanh approximation of the GELU activation."""
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


class Conv1D(nn.Module):
    """Affine map ``x @ weight + bias`` with ``weight`` stored as (nx, nf).

    Named after the TF GPT-2 implementation; it acts like ``nn.Linear`` with a
    transposed weight matrix, applied over the last dimension of ``x``.
    """

    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        # Flatten all leading dims, apply the affine map, then restore them.
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        # Normalise over the last dimension using the biased variance.
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class Attention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache."""

    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular causal mask, shape (1, 1, n_ctx, n_ctx).
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        # k arrives already transposed: (batch, head, head_features, seq_length).
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        # With a cache, the nd new queries are the last nd rows of the causal mask.
        b = self.bias[:, :, ns - nd:ns, :ns]
        w = w * b - 1e10 * (1 - b)
        w = nn.Softmax(dim=-1)(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None):
        """Return ``(output, present)``.

        ``present`` stacks (key, value) for reuse as ``layer_past`` on the next call.
        Each is laid out as (batch, head, seq_length, head_features).
        """
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present


class MLP(nn.Module):
    """Position-wise feed-forward network: Conv1D -> GELU -> Conv1D."""

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return h2


class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual attention followed by residual MLP."""

    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None):
        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present


class GPT2Model(nn.Module):
    """GPT-2 transformer trunk: token + position embeddings, ``n_layer`` blocks, final LayerNorm."""

    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def set_embeddings_weights(self, model_embeddings_weights):
        """Attach a bias-free ``decoder`` Linear whose weight is the given embedding matrix (tied)."""
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        """Run the trunk.

        Args:
            input_ids: LongTensor of token ids; the last dim is the sequence.
            position_ids: optional explicit positions; defaults to continuing after ``past``.
            token_type_ids: optional ids embedded with ``wte`` and added to the input.
            past: optional list of per-layer ``present`` tensors from a previous call.

        Returns:
            ``(hidden_states, presents)``. ``hidden_states`` has ``input_ids``' shape
            plus a trailing ``n_embd`` dim; ``presents`` has one cache tensor per layer.

        Raises:
            ValueError: if a token id is outside ``[0, vocab_size)`` or the total
                length (cached + new) exceeds the position-embedding table.
        """
        # Check up front: out-of-range ids otherwise fail deep inside nn.Embedding,
        # on CUDA as an opaque device-side assert.
        if ((input_ids < 0) | (input_ids >= self.n_vocab)).any():
            raise ValueError(
                f"Invalid token ID found in input_ids: expected values in [0, {self.n_vocab}), "
                f"got min={input_ids.min().item()}, max={input_ids.max().item()}"
            )
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)  # cached sequence length
        n_positions = self.wpe.num_embeddings
        if past_length + input_ids.size(-1) > n_positions:
            raise ValueError(
                f"Sequence too long: {past_length} cached + {input_ids.size(-1)} new tokens "
                f"exceeds the model's {n_positions} positions"
            )
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                                        dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds

        presents = []
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, layer_past)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents


class GPT2LMHead(nn.Module):
    """Language-model head: projects hidden states to vocabulary logits with tied embeddings."""

    def __init__(self, model_embeddings_weights, config):
        super(GPT2LMHead, self).__init__()
        self.n_embd = config.n_embd
        self.set_embeddings_weights(model_embeddings_weights)

    def set_embeddings_weights(self, model_embeddings_weights):
        """(Re)build the bias-free decoder so its weight *is* the given embedding matrix."""
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, hidden_state):
        # Logits for every position; callers slice what they need.
        lm_logits = self.decoder(hidden_state)
        return lm_logits
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a batch of logits using top-k and/or nucleus (top-p) filtering.

    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).

    Args:
        logits: logits of shape (batch size, vocabulary size). Modified in place.
        top_k: if > 0, keep only the ``top_k`` highest logits in each row.
        top_p: if > 0.0, keep the highest-probability tokens of each row up to and
            including the first one at which the cumulative probability exceeds ``top_p``.
        filter_value: value written into every filtered position.

    Returns:
        The same ``logits`` tensor, with filtered entries set to ``filter_value``.
    """
    assert logits.dim() == 2, "Expected logits dimension to be 2 (batch size x vocabulary size)"
    # int(): UI sliders may hand us floats, which torch.topk rejects.
    top_k = min(int(top_k), logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a logit below the k-th best of their row.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is kept as well.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order, row by row.
        # (Indexing logits with flattened sorted indices would hit rows, not columns.)
        to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
        logits[to_remove] = filter_value
    return logits


def _apply_repetition_penalty(logits, token_ids, penalty):
    """Return ``logits`` with every token in ``token_ids`` penalised (CTRL-style).

    Positive logits are divided by ``penalty`` and negative ones multiplied by it, so
    previously seen tokens always become less likely when ``penalty > 1``.

    Args:
        logits: (batch, vocab) next-token logits.
        token_ids: (batch, seq) LongTensor of ids already in each sequence.
        penalty: penalty factor; 1.0 leaves the logits untouched.
    """
    if penalty == 1.0:
        return logits
    seen = torch.gather(logits, 1, token_ids)
    seen = torch.where(seen < 0, seen * penalty, seen / penalty)
    return logits.scatter(1, token_ids, seen)


class GPT2LMHeadModel(nn.Module):
    """GPT-2 trunk plus a tied language-model head, with a sampling ``generate`` helper."""

    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__()
        # Kept so generate() can read generation settings such as eos_token_id.
        self.config = config
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)

    def set_tied(self):
        """ Make sure we are sharing the embeddings
        """
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        """Return ``(lm_logits, presents)``, or ``(loss, lm_logits, presents)`` when ``lm_labels`` is given.

        The loss is next-token cross-entropy: logits at position t are scored against label t+1.
        """
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)
        outputs = (lm_logits, presents)
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs
        return outputs

    def generate(
        self,
        input_ids,
        max_length,
        temperature=1.0,
        top_k=0,
        top_p=0.9,
        repetition_penalty=1.0,
        device=None,
    ):
        """Sample up to ``max_length`` new tokens after ``input_ids``.

        Uses the key/value cache, so each step feeds only the newest token.

        Args:
            input_ids: (batch, seq) LongTensor prompt.
            max_length: maximum number of new tokens to sample.
            temperature: logits are divided by this before filtering.
            top_k, top_p: forwarded to ``top_k_top_p_filtering``.
            repetition_penalty: CTRL-style penalty on tokens already in the sequence.
            device: where to run; defaults to the device of the model's parameters.

        Returns:
            (batch, seq + n_new) LongTensor. If ``config.eos_token_id`` is set,
            sequences that have emitted EOS are padded with EOS. Generation stops
            once every sequence has finished.
        """
        self.eval()
        if device is None:
            device = next(self.parameters()).device
        input_ids = input_ids.to(device)
        eos_token_id = getattr(self.config, "eos_token_id", None)
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)
        generated = input_ids
        step_input = input_ids
        past = None
        with torch.no_grad():
            for _ in range(int(max_length)):
                lm_logits, past = self(step_input, past=past)
                next_token_logits = lm_logits[:, -1, :]
                next_token_logits = _apply_repetition_penalty(next_token_logits, generated, repetition_penalty)
                next_token_logits = next_token_logits / temperature
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                if eos_token_id is not None:
                    # Already-finished rows keep emitting EOS as padding.
                    next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_token_id)
                    finished |= next_token.squeeze(-1) == eos_token_id
                generated = torch.cat((generated, next_token), dim=1)
                if eos_token_id is not None and finished.all():
                    break
                step_input = next_token  # the cache already covers earlier tokens
        return generated


class GPT2Config(object):
    """Hyper-parameters for the GPT-2 model (defaults match the 124M "small" model)."""

    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        eos_token_id=50256,  # id of GPT-2's <|endoftext|> token; used by generate()
    ):
        self.vocab_size = vocab_size_or_config_json_file
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.eos_token_id = eos_token_id


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = GPT2Config()
model = GPT2LMHeadModel(config)
# NOTE: torch.load unpickles the checkpoint; only point this at trusted files.
state_dict = torch.load(r'C:\vision_model\gpt-2-Pytorch\test\gpt_today\weights\epoch_1.pth', map_location='cpu' if not torch.cuda.is_available() else None)
model = load_weight(model, state_dict)
model.to(device)
print(model)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# Define the generation function
def generate_text(prompt_text, max_length=50, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0):
    """Sample a response to ``prompt_text`` from the module-level ``model``.

    The prompt is wrapped as ``"\\n[WP] {prompt_text} \\n[RESPONSE]"``. Sampling stops
    after ``max_length`` new tokens or at ``<|endoftext|>``. Only the text before any
    further ``[WP] ...\\n`` / ``[RESPONSE]`` marker the model produces is returned.

    Args:
        prompt_text: user prompt.
        max_length: maximum number of new tokens (Gradio sliders pass floats; cast to int).
        temperature: logits are divided by this before filtering.
        top_k, top_p: forwarded to ``top_k_top_p_filtering``.
        repetition_penalty: CTRL-style penalty on tokens already present (prompt included).

    Raises:
        gr.Error: if the prompt leaves no room in the model's context window.
    """
    prompt = f"\n[WP] {prompt_text} \n[RESPONSE]"
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    prompt_len = input_ids.size(-1)

    # Never exceed the position-embedding table.
    new_token_budget = min(int(max_length), config.n_positions - prompt_len)
    if new_token_budget <= 0:
        raise gr.Error(f"Prompt is too long: {prompt_len} tokens leaves no room within {config.n_positions}.")
    penalty = float(repetition_penalty)

    past = None
    step_input = input_ids
    with torch.no_grad():
        for _ in range(new_token_budget):
            lm_logits, past = model(step_input, past=past)
            next_token_logits = lm_logits[:, -1, :] / temperature

            # CTRL-style repetition penalty: always pushes seen tokens *down*,
            # whatever the sign of their logit.
            if penalty != 1.0:
                seen = torch.gather(next_token_logits, 1, input_ids)
                seen = torch.where(seen < 0, seen * penalty, seen / penalty)
                next_token_logits = next_token_logits.scatter(1, input_ids, seen)

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=int(top_k), top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break
            step_input = next_token  # the cache already covers earlier tokens

    # Decode only what the model produced, so prompt text can never confuse the split.
    response_text = tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True)
    # Keep just the first response if the model went on to start a new section.
    return re.split(r"\[WP\].*?\n|\[RESPONSE\]", response_text)[0]


# Define the Gradio interface using Blocks
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# GPT-2 Text Generator")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=2, placeholder="Enter prompt here...", label="Prompt")
            max_length = gr.Slider(minimum=10, maximum=100, step=1, value=50, label="Max Length")
            temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
            top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top K")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.95, label="Top P")
            repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Repetition Penalty")
            generate_button = gr.Button("Generate")
        with gr.Column():
            output_text = gr.Textbox(lines=20, label="Generated Text")

    # Event listeners must be registered inside the Blocks context.
    generate_button.click(
        fn=generate_text,
        inputs=[prompt, max_length, temperature, top_k, top_p, repetition_penalty],
        outputs=output_text,
    )

demo.launch()