import pickle

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---- Device Setup ----
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Memory Management ----
def save_memory(memory, filename='chat_memory.pkl'):
    with open(filename, 'wb') as f:
        pickle.dump(memory, f)

def load_memory(filename='chat_memory.pkl'):
    try:
        with open(filename, 'rb') as f:
            return pickle.load(f)
    except (FileNotFoundError, EOFError):
        return []  # Return an empty list if the file is missing or empty

session_memory = load_memory()
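
# Each memory entry is a dict of the form {"input": "<user message>"},
# appended once per turn in simple_chat below and persisted across restarts.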
# ---- Character-Level RNN Model ----
class CharRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CharRNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        # x: (batch, seq_len, input_size) one-hot floats
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out[:, -1, :])  # Use the last time step only
        return out, hidden

    def init_hidden(self, batch_size):
        # nn.RNN expects a hidden state of shape (num_layers, batch, hidden_size)
        return torch.zeros(1, batch_size, self.hidden_size).to(device)
# ---- PHI Model ----
class PHIModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(PHIModel, self).__init__()
        self.phi = (1 + np.sqrt(5)) / 2  # Golden ratio, ~1.618
        self.fc1 = nn.Linear(input_size, int(input_size * self.phi))
        self.fc2 = nn.Linear(int(input_size * self.phi), output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
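
# Note: PHIModel is defined but never used in the chat pipeline below.
# A minimal usage sketch (the sizes here are hypothetical, not from the app):
#   phi_net = PHIModel(input_size=16, output_size=4)
#   logits = phi_net(torch.randn(1, 16))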
# ---- Helper Functions ----
# Generate a sequence of characters as a response to the input
def generate_response_rnn(model, input_text, char_to_idx, idx_to_char, max_len=100):
    vocab_size = len(char_to_idx)
    # Drop characters the vocabulary has no index for, to avoid KeyErrors
    indices = [char_to_idx[c] for c in input_text if c in char_to_idx]
    if not indices:
        indices = [0]  # Fall back to the first vocabulary character
    # One-hot encode to match nn.RNN's expected (batch, seq_len, input_size) float input
    input_tensor = F.one_hot(torch.tensor(indices), num_classes=vocab_size).float().unsqueeze(0).to(device)
    hidden = model.init_hidden(1)
    output_str = input_text

    # Generate characters one at a time, sampling from the softmax distribution
    for _ in range(max_len):
        output, hidden = model(input_tensor, hidden)
        prob = F.softmax(output, dim=1)
        predicted_idx = torch.multinomial(prob, 1).item()
        output_str += idx_to_char[predicted_idx]
        input_tensor = F.one_hot(torch.tensor([[predicted_idx]]), num_classes=vocab_size).float().to(device)
    return output_str
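
# Hypothetical direct call, outside the Gradio app (assumes a trained model;
# the variable names here are illustrative only):
#   c2i, i2c = prepare_data("hello there, how can I assist you today?")
#   rnn = CharRNN(len(c2i), 128, len(c2i)).to(device)
#   print(generate_response_rnn(rnn, "hello", c2i, i2c, max_len=20))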
# ---- Training Data ----
def prepare_data(text):
    # Map each unique character to an index, and back
    chars = sorted(list(set(text)))
    char_to_idx = {char: idx for idx, char in enumerate(chars)}
    idx_to_char = {idx: char for idx, char in enumerate(chars)}
    return char_to_idx, idx_to_char
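
# For example, prepare_data("aba") yields char_to_idx == {'a': 0, 'b': 1}
# and idx_to_char == {0: 'a', 1: 'b'}.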
# ---- Chat Interface ----
def simple_chat(user_input):
    session_memory.append({"input": user_input})
    save_memory(session_memory)

    # Training data (for simplicity, the vocabulary comes from a sample text)
    sample_text = "hello there, how can I assist you today?"
    char_to_idx, idx_to_char = prepare_data(sample_text)

    # Initialize the RNN model with matching input/output sizes
    input_size = len(char_to_idx)
    hidden_size = 128  # Arbitrary size for the hidden layer
    output_size = len(char_to_idx)
    model = CharRNN(input_size, hidden_size, output_size).to(device)

    # Load pre-trained weights if available; otherwise fall back to the
    # randomly initialized model (output will be gibberish until trained)
    try:
        model.load_state_dict(torch.load('char_rnn_model.pth', map_location=device))
    except FileNotFoundError:
        pass
    model.eval()

    # Generate a response using the model
    return generate_response_rnn(model, user_input, char_to_idx, idx_to_char)
# ---- Gradio Interface ----
def chat_interface(user_input):
    return simple_chat(user_input)
# ---- Gradio App Setup ----
with gr.Blocks() as app:
    gr.Markdown("# Chatbot with Neural Network and Text Generation")

    with gr.Row():
        with gr.Column(scale=1):
            user_input = gr.Textbox(label="What will you say?", placeholder="Type something here...")
            submit_button = gr.Button("Send")
        with gr.Column(scale=1):
            chatbot = gr.Textbox(label="Chatbot Response", interactive=False)  # Plain textbox for output

    # Custom styling for the UI
    gr.HTML("""
    <style>
        .gradio-container {
            background-color: #F0F8FF;
            padding: 20px;
            border-radius: 15px;
            font-family: 'Arial';
        }
        .gradio-row {
            display: flex;
            justify-content: space-between;
        }
    </style>
    """)

    # Wire the button click to the chat handler
    submit_button.click(chat_interface, inputs=user_input, outputs=chatbot)

# Launch the Gradio app
app.launch()