azsalihu committed on
Commit b15e74b · verified
1 Parent(s): 7de4f02

Create app.py

Files changed (1)
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
+ import pandas as pd
+ import torch
+ import gradio as gr
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+ from sentence_transformers import SentenceTransformer
+ from tqdm import tqdm
+ from peewee import SqliteDatabase, Model, BlobField
+
+ # Load the Enron Email Dataset
+ emails_df = pd.read_csv("/content/emails.csv")
+
+ # Define the SQLite database (via Peewee) used to store email embeddings
+ db = SqliteDatabase('email_embeddings.db')
+
+ # Define the Peewee models backed by that database
+ class BaseModel(Model):
+     class Meta:
+         database = db
+
+ class Email(BaseModel):
+     embedding = BlobField()
+
+ # Define the language model and tokenizer
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ model = GPT2LMHeadModel.from_pretrained("gpt2")
+ tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
+
+ # Prepare the dataset for fine-tuning the language model: tokenize the emails,
+ # truncating and padding them (the fine-tuning loop itself is not implemented here)
+ encodings = tokenizer(
+     emails_df['Message'].tolist(),
+     return_tensors="pt",
+     max_length=512,
+     truncation=True,
+     padding="max_length",
+ )
+ input_ids = encodings["input_ids"]
+
+ # Define a PyTorch dataset over the tokenized emails
+ dataset = torch.utils.data.TensorDataset(input_ids)
+
+ # Define the Sentence Transformer model used to embed the emails
+ sentence_model = SentenceTransformer('distilbert-base-nli-mean-tokens')
+
+ # Create embeddings of the email dataset and store them in the database
+ def create_embeddings():
+     db.connect()
+     db.create_tables([Email])
+
+     embeddings = []
+     for index, row in tqdm(emails_df.iterrows(), total=len(emails_df)):
+         text = row['Message']
+         embeddings.append(sentence_model.encode(text))
+
+     for index, embedding in enumerate(embeddings):
+         Email.create(id=index, embedding=embedding.tobytes())
+
+     db.close()
+
+ # Answer a question by generating text with the language model
+ def answer_question(question):
+     # Encode the question
+     inputs = tokenizer(question, return_tensors="pt", max_length=512, truncation=True)
+
+     # Generate a response with the model
+     outputs = model.generate(**inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
+
+     # Decode the response
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return response
+
+ # Uncomment the line below to create embeddings and store them in the database
+ # create_embeddings()
+
+ # Create and launch the Gradio interface
+ gr.Interface(fn=answer_question, inputs="text", outputs="text").launch()