ENRON_EMAILS_QA / app.py
azsalihu's picture
Create app.py
b15e74b verified
raw
history blame
2.31 kB
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
from peewee import BlobField, Model, SqliteDatabase, TextField
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModel
# Load the Enron Email Dataset
# NOTE(review): the code below assumes the CSV has a 'Message' column — confirm
# against the actual file header.
emails_df = pd.read_csv("/content/emails.csv")

# Define the SQLite database (accessed through Peewee) that stores the email embeddings
db = SqliteDatabase('email_embeddings.db')

# Define the model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Prepare the dataset for fine-tuning the language model
# Tokenize the dataset. Truncation is requested explicitly (as answer_question
# already does) instead of relying on max_length alone to cap each email.
tokenized_texts = [
    tokenizer.encode(text, return_tensors="pt", max_length=512, truncation=True)
    for text in emails_df['Message']
]

# Convert tokenized texts to a single tensor by joining them along dim=1
# NOTE(review): each encoded email appears to be a (1, n_tokens) tensor, so this
# produces one (1, total_tokens) row and the TensorDataset below would hold a
# single sample — confirm that is what the fine-tuning step expects.
input_ids = torch.cat(tokenized_texts, dim=1)

# Define a PyTorch dataset
dataset = torch.utils.data.TensorDataset(input_ids)

# Define the Sentence Transformer model
sentence_model = SentenceTransformer('distilbert-base-nli-mean-tokens')
# Function to create embeddings of the email dataset and store them in the SQLite database
def create_embeddings():
    """Encode every email with ``sentence_model`` and store the vectors.

    Each value of ``emails_df['Message']`` is encoded and written to the
    ``Email`` table. Rows are keyed by their position in iteration order
    (0, 1, 2, ...) and the vector is stored as the result of ``.tobytes()``.

    Rows are inserted as they are encoded rather than collected in a list
    first, so memory use does not grow with the size of the dataset. All
    inserts share one transaction: an error part-way through rolls the
    inserts back, and the connection is closed in every case.
    """
    db.connect()
    try:
        db.create_tables([Email])
        # A single transaction avoids one commit per row and keeps the run
        # all-or-nothing if encoding or inserting fails.
        with db.atomic():
            messages = tqdm(emails_df['Message'], total=len(emails_df))
            for index, text in enumerate(messages):
                embedding = sentence_model.encode(text)
                # NOTE(review): raw bytes are handed to Email.embedding —
                # confirm the column type accepts binary data.
                Email.create(id=index, embedding=embedding.tobytes())
    finally:
        db.close()
# Define the function served by the Gradio Interface
def answer_question(question):
    """Generate a text continuation for ``question`` with the GPT-2 model.

    Args:
        question: The user's input text. ``None``, empty, or whitespace-only
            input is answered with an empty string without calling the model.

    Returns:
        The decoded model output (prompt plus generated continuation) with
        special tokens removed, or ``""`` for blank input.
    """
    # A blank prompt tokenizes to zero tokens, which generate() cannot extend.
    if not question or not question.strip():
        return ""

    # Encode the question
    inputs = tokenizer(question, return_tensors="pt", max_length=512, truncation=True)

    # Generate response using the model.
    # max_new_tokens budgets only the generated part; without it generate()
    # falls back to a small total max_length that also counts the prompt
    # tokens, leaving little or no room for an answer to a longer question.
    # GPT-2 defines no pad token, so EOS is passed as the padding id.
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode the response
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Define the Peewee models for the SQLite embeddings database
class BaseModel(Model):
    """Base model that binds its subclasses to the shared ``db`` database."""

    class Meta:
        database = db


class Email(BaseModel):
    """One stored email embedding; Peewee supplies the integer ``id`` key."""

    # Embeddings are written as raw bytes (``ndarray.tobytes()``). A BlobField
    # stores them unchanged, whereas a TextField is meant for unicode text and
    # coerces bytes to str, which arbitrary float bytes do not survive.
    embedding = BlobField()
# Create a Gradio Interface
demo = gr.Interface(fn=answer_question, inputs="text", outputs="text")

# Creating the embeddings and storing them in the SQLite database is a long,
# one-off job, so it is opt-in: set the environment variable BUILD_EMBEDDINGS=1
# to run it. It is placed before launch() because launch() blocks when the file
# is run as a script, which would otherwise keep this step from ever running
# while the app is up.
if os.environ.get("BUILD_EMBEDDINGS") == "1":
    create_embeddings()

demo.launch()