# ENRON_EMAILS_QA / app.py
# azsalihu's picture
# Create app.py
# b15e74b verified
import pandas as pd
import numpy as np
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModel
import torch
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
from peewee import SqliteDatabase, Model, TextField
# Load the Enron email dataset; the 'Message' column is used as the email text.
emails_df = pd.read_csv("/content/emails.csv")

# SQLite file (accessed through peewee) in which the email embeddings are stored.
db = SqliteDatabase('email_embeddings.db')

# Pretrained GPT-2 tokenizer and language model used by answer_question.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Tokenize every email. truncation=True is required for max_length to take
# effect; without it the tokenizer only warns and returns the full sequence.
# This matches the tokenizer call in answer_question.
tokenized_texts = [
    tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
    for text in emails_df['Message']
]

# Join the per-email token tensors along the sequence axis into one long row.
input_ids = torch.cat(tokenized_texts, dim=1)

# Wrap the token ids in a PyTorch dataset.
# NOTE(review): nothing in this file trains on `dataset`; GPT-2 is used as
# downloaded. Confirm whether a fine-tuning step is still planned.
dataset = torch.utils.data.TensorDataset(input_ids)

# Sentence-embedding model used by create_embeddings.
sentence_model = SentenceTransformer('distilbert-base-nli-mean-tokens')
def create_embeddings(batch_size=256):
    """Embed every email and store the vectors in the SQLite ``Email`` table.

    Emails are encoded in batches of ``batch_size`` and each batch is written
    in its own transaction, so memory use stays bounded and a failure part-way
    through keeps the batches already committed.

    Each vector is stored as base64 text of its raw bytes because the
    ``embedding`` column is a text field, which cannot hold arbitrary binary
    data. To read a vector back, base64-decode the column and pass the bytes
    to ``np.frombuffer`` with the dtype that ``sentence_model.encode`` returns
    (presumably float32 -- confirm before relying on it).

    Args:
        batch_size: Number of emails encoded and inserted per batch. Each row
            binds two SQL variables, so keep ``2 * batch_size`` below the
            SQLite bound-variable limit (999 on older builds).

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    import base64

    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # connect() returns True only when it opened a new connection; remember
    # that so a connection owned by the caller is not closed underneath them.
    opened_here = db.connect(reuse_if_open=True)
    try:
        # create_tables is a no-op when the table already exists.
        db.create_tables([Email])
        messages = emails_df['Message'].tolist()
        for start in tqdm(range(0, len(messages), batch_size)):
            vectors = sentence_model.encode(messages[start:start + batch_size])
            rows = [
                {
                    "id": start + offset,
                    "embedding": base64.b64encode(vector.tobytes()).decode("ascii"),
                }
                for offset, vector in enumerate(vectors)
            ]
            # INSERT OR REPLACE keeps re-runs idempotent instead of failing
            # on rows whose id is already present.
            with db.atomic():
                Email.replace_many(rows).execute()
    finally:
        if opened_here:
            db.close()
# Upper bound on the number of tokens generated for one answer. The prompt is
# truncated to 512 tokens, so prompt + answer stays within GPT-2's 1024-token
# context window.
MAX_NEW_TOKENS = 64


def answer_question(question):
    """Generate a GPT-2 continuation for ``question`` and return it as text.

    Args:
        question: The user's question as plain text.

    Returns:
        The decoded model output (prompt followed by the generated
        continuation), or a short prompt message when ``question`` is empty
        or only whitespace.
    """
    # An empty prompt gives the model nothing to condition on; answer directly.
    if not question or not question.strip():
        return "Please enter a question."

    # Encode the question, truncating over-long input to 512 tokens.
    inputs = tokenizer(question, return_tensors="pt", max_length=512, truncation=True)

    # max_new_tokens bounds only the generated part. The default total
    # max_length of 20 also counts the prompt, so longer questions would leave
    # no room for an answer. GPT-2 has no pad token, so EOS is used for padding.
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode the single generated sequence back to text.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Peewee models backed by the SQLite database `db`.
class BaseModel(Model):
    """Common base that binds every derived model to ``db``."""

    class Meta:
        # All subclasses share this database connection.
        database = db


class Email(BaseModel):
    """One stored email embedding; peewee supplies the implicit ``id`` key."""

    # Serialized embedding for the email, held in a text column.
    embedding = TextField()
# Build the text-in / text-out Gradio UI around answer_question and serve it.
demo = gr.Interface(fn=answer_question, inputs="text", outputs="text")
demo.launch()

# Populate the SQLite embedding table.
# NOTE(review): launch() normally blocks when run as a script, so this call is
# presumably reached only after the server stops -- confirm whether the
# embeddings should be built before launching instead.
create_embeddings()