import torch
import torch.nn.functional as F

from tokenizers import Tokenizer

from train import GPT, GPTConfig
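
# Inference script: rebuilds the GPT model defined in train.py, loads the
# trained weights (best_model.pt) and tokenizer (az_tokenizer.json) from the
# working directory, and samples continuations with temperature scaling, a
# repetition penalty, and nucleus (top-p) sampling.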


def nucleus_sampling(logits, p=0.9):
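    """Sample one token id from a 1-D logits tensor using nucleus (top-p) filtering.

    Only the smallest set of tokens whose cumulative probability exceeds p
    is kept; everything else is masked to -inf before sampling. For example,
    with sorted probabilities [0.6, 0.25, 0.1, 0.05] and p=0.9, the
    cumulative sums are [0.6, 0.85, 0.95, 1.0]; the third token crosses the
    threshold, so the first three stay candidates and the last is masked.
    """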
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Mask tokens once the cumulative probability exceeds p, shifting the mask
    # right so the token that crosses the threshold is kept; the top token
    # always survives.
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    logits[sorted_indices[sorted_indices_to_remove]] = -float('Inf')
    probabilities = F.softmax(logits, dim=-1)
    next_token_id = torch.multinomial(probabilities, num_samples=1).item()
    return next_token_id


def load_model_and_tokenizer():
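    """Rebuild the GPT from its config, then load the trained weights and tokenizer."""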
    config = GPTConfig()
    model = GPT(config)
    # Weights are mapped to CPU here; generate_text moves the model to the
    # requested device.
    model.load_state_dict(torch.load('best_model.pt', map_location=torch.device('cpu')))
    model.eval()
    tokenizer = Tokenizer.from_file("az_tokenizer.json")
    return model, tokenizer


def apply_repetition_penalty(logits, input_ids, penalty=1.2):
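    """Penalize tokens that already appear in input_ids (CTRL-style penalty).

    Positive logits are divided by penalty and negative logits multiplied
    by it, so repeated tokens always become less likely.
    """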
    for token_id in set(input_ids):
        # Shrink positive logits and amplify negative ones; unconditionally
        # dividing would make a token with a negative logit *more* likely
        # instead of less.
        if logits[0, token_id] > 0:
            logits[0, token_id] /= penalty
        else:
            logits[0, token_id] *= penalty
    return logits


def generate_text(model, tokenizer, prompt, max_new_tokens=50,
                  temperature=0.001, p=0.95, repetition_penalty=1.5,
                  device='cpu'):
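    """Autoregressively sample up to max_new_tokens tokens continuing prompt.

    Note: the default temperature of 0.001 makes decoding nearly greedy;
    values around 0.7-1.0 give more varied samples.
    """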
    model = model.to(device)
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)

    for _ in range(max_new_tokens):
        with torch.no_grad():
            output_logits, _ = model(input_tensor)

        # Logits for the last position, scaled by temperature.
        logits = output_logits[:, -1, :] / temperature

        # Penalize tokens already in the context; clone() keeps the penalty
        # from mutating the scaled logits in place.
        logits = apply_repetition_penalty(logits.clone(), input_ids,
                                          penalty=repetition_penalty)

        next_token_id = nucleus_sampling(logits[0], p=p)

        input_ids.append(next_token_id)
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(device)
        # Crop to the model's context window so long generations cannot overrun
        # the positional embeddings. This assumes a nanoGPT-style model that
        # keeps block_size on model.config, hence the defensive getattr.
        block_size = getattr(getattr(model, 'config', None), 'block_size', None)
        if block_size is not None:
            input_tensor = input_tensor[:, -block_size:]

        # Stop at the end-of-text marker; token_to_id returns None when the
        # tokenizer has no [END] token, in which case this never matches.
        if next_token_id == tokenizer.token_to_id('[END]'):
            break

    generated_text = tokenizer.decode(input_ids)
    # Strip stray ' i ' fragments that the tokenizer leaves in decoded text.
    return generated_text.replace(' i ', ' ')


def main():
    model, tokenizer = load_model_and_tokenizer()
    prompt = "Azərbaycanın tarixi"
    generated_text = generate_text(model, tokenizer, prompt, p=0.9)
    print("Generated Text:", generated_text)


if __name__ == '__main__':
    main()