# Hugging Face Spaces status when this script was captured: Runtime error
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
from peewee import BlobField, Model, SqliteDatabase, TextField
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer
# Load the Enron Email Dataset.  The default path is the Colab location this
# script was written for; set the EMAILS_CSV environment variable to read the
# CSV from somewhere else (e.g. when running on Hugging Face Spaces).
EMAILS_CSV_PATH = os.environ.get("EMAILS_CSV", "/content/emails.csv")
emails_df = pd.read_csv(EMAILS_CSV_PATH)
# NOTE(review): the commonly distributed Kaggle Enron dump names its body
# column "message" (lowercase) -- confirm this CSV really has "Message".

# SQLite database (accessed through peewee) that stores sentence embeddings.
db = SqliteDatabase('email_embeddings.db')

# GPT-2 tokenizer and causal language model used by answer_question().
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Prepare causal-LM training data.  No fine-tuning loop runs in this script;
# this only builds the dataset.  Each email is tokenized (truncated to
# BLOCK_SIZE tokens), all emails are concatenated into one token stream, and
# the stream is cut into fixed-size blocks so the dataset holds one sample per
# block instead of a single giant sequence.
BLOCK_SIZE = 512
messages = emails_df['Message'].dropna().astype(str)
tokenized_texts = [
    tokenizer.encode(text, return_tensors="pt", max_length=BLOCK_SIZE, truncation=True)
    for text in messages
    if text  # empty strings contribute no tokens; skip them
]
token_stream = (
    torch.cat(tokenized_texts, dim=1)
    if tokenized_texts
    else torch.empty((1, 0), dtype=torch.long)  # torch.cat([]) would raise
)
num_blocks = token_stream.shape[1] // BLOCK_SIZE
# Drop the trailing partial block so every sample has exactly BLOCK_SIZE tokens.
input_ids = token_stream[0, : num_blocks * BLOCK_SIZE].reshape(num_blocks, BLOCK_SIZE)
dataset = torch.utils.data.TensorDataset(input_ids)

# Sentence-embedding model used by create_embeddings().
sentence_model = SentenceTransformer('distilbert-base-nli-mean-tokens')
# Encode every email with the sentence model and persist the vectors in the
# SQLite table defined by the Email model.
def create_embeddings(batch_size=32):
    """Embed each email body and store the vectors in ``email_embeddings.db``.

    Messages are read from the ``'Message'`` column of ``emails_df``.  Missing
    values are embedded as the empty string so row positions stay aligned.
    Messages are encoded in batches by ``sentence_model`` and written to the
    ``Email`` table as raw ``ndarray.tobytes()`` bytes, keyed by row position.
    Rows are written with REPLACE semantics, so re-running the function
    refreshes the table instead of failing on duplicate primary keys.

    Args:
        batch_size: Number of messages per ``sentence_model.encode`` batch.
    """
    texts = emails_df['Message'].fillna("").astype(str).tolist()
    # One batched encode call is much faster than encoding row by row.
    embeddings = (
        sentence_model.encode(texts, batch_size=batch_size, show_progress_bar=True)
        if texts
        else []
    )
    # ``with db`` opens the connection and a transaction, commits on success,
    # rolls back on error, and always closes the connection again.
    with db:
        db.create_tables([Email])
        for index, embedding in enumerate(embeddings):
            Email.replace(id=index, embedding=embedding.tobytes()).execute()
# Gradio callback: turn the user's question into a GPT-2 continuation.
def answer_question(question, max_new_tokens=50):
    """Generate a GPT-2 reply to *question*.

    NOTE: this does not consult the stored email embeddings; the reply comes
    purely from the (un-fine-tuned) GPT-2 model.

    Args:
        question: Prompt text from the Gradio text box.
        max_new_tokens: Maximum number of tokens generated beyond the prompt.

    Returns:
        The generated continuation with the echoed prompt removed, or ``""``
        when the question is empty or whitespace-only.
    """
    # generate() cannot start from an empty prompt.
    if question is None or not question.strip():
        return ""
    inputs = tokenizer(question, return_tensors="pt", max_length=512, truncation=True)
    prompt_length = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        # Bound the *new* tokens; the default total max_length leaves no room
        # for an answer once the prompt is long.
        max_new_tokens=max_new_tokens,
        # GPT-2 has no pad token; reuse EOS to silence the padding warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the continuation.
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Peewee models backing the SQLite embedding store (email_embeddings.db).
class BaseModel(Model):
    """Base model that binds every subclass to the module-level ``db``."""

    class Meta:
        database = db


class Email(BaseModel):
    """One stored email embedding.

    ``id`` is peewee's implicit auto-increment primary key; create_embeddings
    sets it explicitly to the email's row position.
    """

    # Raw vector bytes from ``ndarray.tobytes()``.  This must be a BLOB: a
    # TextField would try to UTF-8-decode the bytes and fail on float data.
    # Presumably float32 from the sentence encoder -- confirm before decoding
    # with np.frombuffer.
    embedding = BlobField()
# Set to True to (re)build the email embedding table before the app starts.
# create_embeddings() encodes every email, which is slow on a large dataset,
# so it is opt-in.
BUILD_EMBEDDINGS = False

# Gradio UI; kept at module level so tools that import the app can find it.
demo = gr.Interface(fn=answer_question, inputs="text", outputs="text")

if __name__ == "__main__":
    # Build embeddings *before* serving: launch() blocks until the server
    # shuts down, so code placed after it never runs while the app is up.
    if BUILD_EMBEDDINGS:
        create_embeddings()
    demo.launch()