3ed0k4 committed on
Commit
ef404e7
·
verified ·
1 Parent(s): 075d99e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import streamlit as st
4
+ import torch
5
+ from src.model import TransformerModel
6
+ from src.utils import load_vocab, tokenize
7
+ import time
8
+ import random
9
+ import os
10
+
11
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Checkpoint served by the app; update this path when a newer model is trained.
MODEL_PATH = 'models/3ed0k4_model_epoch10.pth'
# Vocabulary file read by load_vocab().
VOCAB_PATH = 'vocab.json'

# Hyper-parameters passed to the TransformerModel constructor in load_model().
EMBED_SIZE = 256
NUM_HEADS = 8
HIDDEN_DIM = 512
NUM_LAYERS = 4
DROPOUT = 0.1

# Default upper bound on the number of tokens generate_text() appends.
MAX_LENGTH = 100

# ---------------------------------------------------------------------------
# Page title and description
# ---------------------------------------------------------------------------
st.title("3ed0k4 NLP Text Generation Model 🚀")
st.write("Enter a prompt, and the model will generate text based on your input. It will take 1 to 10 seconds to respond to simulate 'thinking'.")
24
+
25
+ # Load vocabulary
26
@st.cache_resource
def load_resources():
    """Load the vocabulary stored at ``VOCAB_PATH``.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the loaded
    object instead of re-reading the file on every script rerun.

    Returns:
        The object produced by ``load_vocab(VOCAB_PATH)``. The rest of this
        script uses it as a token -> id mapping (``len``, ``.get``,
        ``.items()``).
    """
    return load_vocab(VOCAB_PATH)


# Module-level handles used by load_model() and generate_text().
vocab = load_resources()
vocab_size = len(vocab)
33
+
34
+ # Initialize model
35
@st.cache_resource
def load_model():
    """Construct the ``TransformerModel`` and load its trained weights on CPU.

    The architecture is built from the module-level hyper-parameter constants
    and ``vocab_size``. If no file exists at ``MODEL_PATH`` a Streamlit error
    is displayed and ``None`` is returned instead of a model.

    Returns:
        The model switched to evaluation mode, or ``None`` when the
        checkpoint file is missing.
    """
    hyperparams = dict(
        vocab_size=vocab_size,
        embed_size=EMBED_SIZE,
        num_heads=NUM_HEADS,
        hidden_dim=HIDDEN_DIM,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT,
    )
    net = TransformerModel(**hyperparams)

    checkpoint_missing = not os.path.exists(MODEL_PATH)
    if checkpoint_missing:
        st.error(f"Model file not found at {MODEL_PATH}. Please ensure the model is trained and the path is correct.")
        return None

    # NOTE(review): torch.load unpickles the file, so MODEL_PATH should only
    # ever point at a trusted checkpoint.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    net.eval()  # inference only: disables dropout
    return net


model = load_model()
53
+
54
def generate_text(prompt, max_length=MAX_LENGTH):
    """Greedily extend *prompt* with up to *max_length* predicted tokens.

    The prompt is split with ``tokenize`` and mapped to ids through the
    module-level ``vocab`` (tokens missing from it map to the ``'<UNK>'`` id).
    At each step the whole sequence so far is fed to the module-level
    ``model`` and the highest-scoring id at the last position is appended.
    Generation stops early when the predicted id equals the ``'<PAD>'`` id
    (0 if the vocabulary has no such entry).

    Args:
        prompt: Text to continue.
        max_length: Maximum number of tokens to append to the prompt.

    Returns:
        The prompt tokens followed by the generated tokens, joined by single
        spaces. Ids with no vocabulary entry are rendered as ``'<UNK>'``.
        An empty string is returned when the prompt yields no tokens.

    Raises:
        KeyError: If ``vocab`` has no ``'<UNK>'`` entry and the prompt
            produced at least one token.
    """
    tokens = list(tokenize(prompt))
    if not tokens:
        # A zero-length sequence has no last position to read a prediction
        # from, so there is nothing to feed the model or to return.
        return ''

    # Resolve the special ids once instead of on every token / iteration.
    unk_id = vocab['<UNK>']
    pad_id = vocab.get('<PAD>', 0)

    generated = [vocab.get(token, unk_id) for token in tokens]

    with torch.no_grad():
        for _ in range(max_length):
            # Batch of one containing everything produced so far.
            input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(0)
            src_mask = model.generate_square_subsequent_mask(input_seq.size(1)).to(input_seq.device)
            outputs = model(input_seq, src_mask)

            # Greedy decoding: take the arg-max over the scores at the final
            # position (assumes outputs is (batch, seq, vocab) — TODO confirm).
            next_token = torch.argmax(outputs[0, -1, :]).item()
            if next_token == pad_id:
                break

            generated.append(next_token)

    # Convert numerical tokens back to words.
    inv_vocab = {idx: word for word, idx in vocab.items()}
    return ' '.join(inv_vocab.get(tok, '<UNK>') for tok in generated)
78
+
79
# User inputs: the prompt to continue and an artificial "thinking" delay.
prompt = st.text_input("Enter your prompt:", "")
delay = st.slider("Select thinking delay (seconds):", min_value=1, max_value=10, value=3)

if st.button("Generate"):
    # load_model() returns None when the checkpoint is missing. Compare by
    # identity rather than truthiness so a loaded model object is never
    # mistaken for "not loaded".
    if model is None:
        st.error("Model is not loaded. Please check the model path.")
    elif not prompt.strip():
        st.warning("Please enter a prompt to generate text.")
    else:
        with st.spinner("Thinking..."):
            # Deliberate pause chosen by the user via the slider above.
            time.sleep(delay)
            response = generate_text(prompt)
        st.success("Here's the generated text:")
        st.write(response)