shakhovak committed on
Commit
706771c
1 Parent(s): fe6f530

added files

Browse files
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9.13
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt /app
6
+
7
+ RUN pip install --no-cache-dir -r requirements.txt
8
+
9
+ COPY . /app
10
+
11
+ RUN useradd -m -u 1000 user
12
+
13
+ USER user
14
+
15
+ ENV HOME=/home/user \
16
+ PATH=/home/user/.local/bin:$PATH
17
+
18
+ WORKDIR $HOME/app
19
+
20
+
21
+ COPY --chown=user . $HOME/app
22
+
23
+ #CMD ["gunicorn", "--timeout", "1000", "app:app", "-b", "0.0.0.0:5000"]
24
+ #CMD ["python", "app.py"]
25
+ CMD ["gunicorn", "--timeout", "1000", "--workers", "2", "--worker-class", "gevent", "--worker-connections" , "100", "app:app", "-b", "0.0.0.0:7860"]
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from flask import Flask, render_template, request
from generate_bot import ChatBot
import asyncio

# Flask application serving the chat UI and the bot answer endpoint.
# The bot (models + retrieval data) is loaded once at startup.
app = Flask(__name__)
chatSheldon = ChatBot()
chatSheldon.load()


@app.route("/")
async def index():
    """Render the single-page chat UI."""
    return render_template("chat.html")


async def sleep():
    """Yield control to the event loop for 0.1 s."""
    await asyncio.sleep(0.1)
    return 0.1


@app.route("/get", methods=["GET", "POST"])
async def chat():
    """Answer one chat message posted by the UI.

    ``request.values`` reads both form data (POST) and the query string
    (GET), so the declared GET method actually works; ``request.form``
    alone would fail with a 400 on GET requests.
    """
    msg = request.values["msg"]
    # Brief cooperative pause before generating the reply.
    await asyncio.gather(sleep(), sleep())
    return get_Chat_response(msg)


def get_Chat_response(text):
    """Delegate to the bot; returns the generated reply string."""
    answer = chatSheldon.generate_response(text)
    return answer


if __name__ == "__main__":
    app.run(debug=True, host="0.0.0.0")
data/scripts.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02be23abb73d94025637264be0813338fba80a81eb1a95074f3437d61392cc73
3
+ size 2099433
data/scripts_reworked.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c0329681a10e682750117eb86c9eace5ef79af5e1c113f0af383ef814bba405
3
+ size 7686225
data/scripts_vectors.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1865b58f9fc16255786cfee8be7f6c120e3986ae1dc7012e07d1cee9f77bdb77
3
+ size 67336895
generate_bot.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ from sentence_transformers import SentenceTransformer
4
+ from utils import generate_response
5
+ import pandas as pd
6
+ import pickle
7
+ from utils import encode_rag, cosine_sim_rag, top_candidates
8
+
9
+
10
class ChatBot:
    """Retrieval-augmented chat bot.

    Ranks the incoming utterance against precomputed script embeddings;
    when a close scripted answer exists it is passed to the generative
    model as extra context (RAG), otherwise the model answers from the
    conversation history alone.
    """

    def __init__(self):
        # Rolling dialogue memory: the last 10 utterances (user and bot).
        self.conversation_history = deque([], maxlen=10)
        self.generative_model = None  # seq2seq answer generator
        self.generative_tokenizer = None  # tokenizer paired with the generator
        self.vect_data = []  # precomputed script embeddings
        self.scripts = []  # script dataframe with candidate answers
        self.ranking_model = None  # sentence-transformer used for retrieval

    def load(self):
        """Load all datasets and models used by the chat bot.

        Must be called once before ``generate_response``. Data is read
        from the local ``data`` folder; models are downloaded from the
        Hugging Face hub.
        """

        with open("data/scripts_vectors.pkl", "rb") as fp:
            self.vect_data = pickle.load(fp)
        self.scripts = pd.read_pickle("data/scripts.pkl")
        self.ranking_model = SentenceTransformer(
            "Shakhovak/chatbot_sentence-transformer"
        )
        self.generative_model = AutoModelForSeq2SeqLM.from_pretrained(
            "Shakhovak/flan-t5-base-sheldon-chat-v2"
        )
        self.generative_tokenizer = AutoTokenizer.from_pretrained(
            "Shakhovak/flan-t5-base-sheldon-chat-v2"
        )

    def generate_response(self, utterance):
        """Produce a reply to ``utterance`` and update the history."""

        # Embed the query together with the rolling conversation history.
        query_encoding = encode_rag(
            texts=utterance,
            model=self.ranking_model,
            contexts=self.conversation_history,
        )

        # Similarity of the query against every stored script vector.
        bot_cosine_scores = cosine_sim_rag(
            self.vect_data,
            query_encoding,
        )

        top_scores, top_indexes = top_candidates(
            bot_cosine_scores, initial_data=self.scripts
        )

        # Strong match (>= 0.89): feed the retrieved scripted answer to the
        # generator as RAG context. NOTE(review): with the default top=1 this
        # loop keeps only the last candidate's answer — confirm intended.
        if top_scores[0] >= 0.89:
            for index in top_indexes:
                rag_answer = self.scripts.iloc[index]["answer"]

            answer = generate_response(
                model=self.generative_model,
                tokenizer=self.generative_tokenizer,
                question=utterance,
                context=self.conversation_history,
                top_p=0.9,
                temperature=0.95,
                rag_answer=rag_answer,
            )
        else:
            answer = generate_response(
                model=self.generative_model,
                tokenizer=self.generative_tokenizer,
                question=utterance,
                context=self.conversation_history,
                top_p=0.9,
                temperature=0.95,
            )

        # Remember both sides of the exchange (deque trims itself to 10).
        self.conversation_history.append(utterance)
        self.conversation_history.append(answer)
        return answer
80
+
81
+
82
+ # katya = ChatBot()
83
+ # katya.load()
84
+ # print(katya.generate_response("What is he doing there?"))
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pandas==2.2.1
2
+ flask[async]==3.0.2
3
+ datasets==2.17.1
4
+ transformers==4.38.1
5
+ gunicorn==21.2.0
6
+ gevent>=1.4
7
+ requests==2.31.0
8
+ scikit-learn==1.4.1.post1
9
+ scipy==1.12.0
10
+ numpy==1.26.4
11
+ torch==2.2.1
12
+ sentence-transformers==2.3.1
static/style.css ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Chat UI stylesheet for templates/chat.html (Bootstrap 4 layout). */

/* Page background: dark gradient behind the chat card. */
body,html{
    height: 100%;
    margin: 0;
    background: rgb(44, 47, 59);
    background: -webkit-linear-gradient(to right, rgb(40, 59, 34), rgb(54, 60, 70), rgb(32, 32, 43));
    background: linear-gradient(to right, rgb(38, 51, 61), rgb(50, 55, 65), rgb(33, 33, 78));
}

/* --- Chat card layout --- */
.chat{
    margin-top: auto;
    margin-bottom: auto;
}
.card{
    height: 500px;
    border-radius: 15px !important;
    background-color: rgba(0,0,0,0.4) !important;
}
.contacts_body{
    padding: 0.75rem 0 !important;
    overflow-y: auto;
    white-space: nowrap;
}
.msg_card_body{
    overflow-y: auto;
}
.card-header{
    border-radius: 15px 15px 0 0 !important;
    border-bottom: 0 !important;
}
.card-footer{
    border-radius: 0 0 15px 15px !important;
    border-top: 0 !important;
}
.container{
    align-content: center;
}

/* --- Search and message input controls --- */
.search{
    border-radius: 15px 0 0 15px !important;
    background-color: rgba(0,0,0,0.3) !important;
    border:0 !important;
    color:white !important;
}
.search:focus{
    box-shadow:none !important;
    outline:0px !important;
}
.type_msg{
    background-color: rgba(0,0,0,0.3) !important;
    border:0 !important;
    color:white !important;
    height: 60px !important;
    overflow-y: auto;
}
.type_msg:focus{
    box-shadow:none !important;
    outline:0px !important;
}
.attach_btn{
    border-radius: 15px 0 0 15px !important;
    background-color: rgba(0,0,0,0.3) !important;
    border:0 !important;
    color: white !important;
    cursor: pointer;
}
.send_btn{
    border-radius: 0 15px 15px 0 !important;
    background-color: rgba(0,0,0,0.3) !important;
    border:0 !important;
    color: white !important;
    cursor: pointer;
}
.search_btn{
    border-radius: 0 15px 15px 0 !important;
    background-color: rgba(0,0,0,0.3) !important;
    border:0 !important;
    color: white !important;
    cursor: pointer;
}

/* --- Contact list (unused by the current template but kept) --- */
.contacts{
    list-style: none;
    padding: 0;
}
.contacts li{
    width: 100% !important;
    padding: 5px 10px;
    margin-bottom: 15px !important;
}
.active{
    background-color: rgba(0,0,0,0.3);
}

/* --- Avatars and online indicator --- */
.user_img{
    height: 70px;
    width: 70px;
    border:1.5px solid #f5f6fa;

}
.user_img_msg{
    height: 40px;
    width: 40px;
    border:1.5px solid #f5f6fa;

}
.img_cont{
    position: relative;
    height: 70px;
    width: 70px;
}
.img_cont_msg{
    height: 40px;
    width: 40px;
}
.online_icon{
    position: absolute;
    height: 15px;
    width:15px;
    background-color: #4cd137;
    border-radius: 50%;
    bottom: 0.2em;
    right: 0.4em;
    border:1.5px solid white;
}
.offline{
    background-color: #c23616 !important;
}
.user_info{
    margin-top: auto;
    margin-bottom: auto;
    margin-left: 15px;
}
.user_info span{
    font-size: 20px;
    color: white;
}
.user_info p{
    font-size: 10px;
    color: rgba(255,255,255,0.6);
}
.video_cam{
    margin-left: 50px;
    margin-top: 5px;
}
.video_cam span{
    color: white;
    font-size: 20px;
    cursor: pointer;
    margin-right: 20px;
}

/* --- Message bubbles. NOTE: "cotainer" is a typo kept on purpose — the
   class names must match those generated in templates/chat.html. --- */
.msg_cotainer{
    margin-top: auto;
    margin-bottom: auto;
    margin-left: 10px;
    border-radius: 25px;
    background-color: rgb(82, 172, 255);
    padding: 10px;
    position: relative;
}
.msg_cotainer_send{
    margin-top: auto;
    margin-bottom: auto;
    margin-right: 10px;
    border-radius: 25px;
    background-color: #58cc71;
    padding: 10px;
    position: relative;
}
.msg_time{
    position: absolute;
    left: 0;
    bottom: -15px;
    color: rgba(255,255,255,0.5);
    font-size: 10px;
}
.msg_time_send{
    position: absolute;
    right:0;
    bottom: -15px;
    color: rgba(255,255,255,0.5);
    font-size: 10px;
}
.msg_head{
    position: relative;
}

/* --- Header action menu --- */
#action_menu_btn{
    position: absolute;
    right: 10px;
    top: 10px;
    color: white;
    cursor: pointer;
    font-size: 20px;
}
.action_menu{
    z-index: 1;
    position: absolute;
    padding: 15px 0;
    background-color: rgba(0,0,0,0.5);
    color: white;
    border-radius: 15px;
    top: 30px;
    right: 15px;
    display: none;
}
.action_menu ul{
    list-style: none;
    padding: 0;
    margin: 0;
}
.action_menu ul li{
    width: 100%;
    padding: 10px 15px;
    margin-bottom: 5px;
}
.action_menu ul li i{
    padding-right: 10px;
}
.action_menu ul li:hover{
    cursor: pointer;
    background-color: rgba(0,0,0,0.2);
}

/* Narrow screens: stack the contacts card above the chat. */
@media(max-width: 576px){
    .contacts_card{
        margin-bottom: 15px !important;
    }
}
templates/chat.html ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html>
<head>
    <title>Chatbot</title>
    <!-- jQuery must load before the Bootstrap JS bundle; the page previously
         loaded two jQuery versions and two Bootstrap CSS versions, and
         emitted tags before the doctype. -->
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
    <script src="//maxcdn.bootstrapcdn.com/bootstrap/4.1.1/js/bootstrap.min.js"></script>
    <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/css/bootstrap.min.css" integrity="sha384-MCw98/SFnGE8fJT3GXwEOngsV7Zt27NXFoaoApmYm81iuXoPkFOJwJ8ERdknLPMO" crossorigin="anonymous">
    <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.5.0/css/all.css" integrity="sha384-B4dIYHKNBt8Bc12p+WXckhzcICo0wtJAoU8YZTY5qE0Id1GSseTk6S+L3BlXeVIU" crossorigin="anonymous">
    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='style.css')}}"/>
</head>

<body>
    <div class="container-fluid h-100">
        <div class="row justify-content-center h-100">
            <div class="col-md-8 col-xl-6 chat">
                <div class="card">
                    <div class="card-header msg_head">
                        <div class="d-flex bd-highlight">
                            <div class="img_cont">
                                <img src="https://stickerpacks.ru/wp-content/uploads/2023/04/nabor-stikerov-teorija-bolshogo-vzryva-5-dlja-telegram-3.webp" class="rounded-circle user_img">
                                <span class="online_icon"></span>
                            </div>
                            <div class="user_info">
                                <span>ChatBot</span>
                                <p>Ask me anything!</p>
                            </div>
                        </div>
                    </div>
                    <div id="messageFormeight" class="card-body msg_card_body">
                    </div>
                    <div class="card-footer">
                        <form id="messageArea" class="input-group">
                            <input type="text" id="text" name="msg" placeholder="Type your message..." autocomplete="off" class="form-control type_msg" required/>
                            <div class="input-group-append">
                                <button type="submit" id="send" class="input-group-text send_btn"><i class="fas fa-location-arrow"></i></button>
                            </div>
                        </form>
                    </div>
                </div>
            </div>
        </div>
    </div>

    <script>
        $(document).ready(function() {
            // Escape user/bot text before splicing it into HTML (XSS guard).
            function escapeHtml(s) {
                return String(s)
                    .replace(/&/g, "&amp;")
                    .replace(/</g, "&lt;")
                    .replace(/>/g, "&gt;")
                    .replace(/"/g, "&quot;");
            }
            $("#messageArea").on("submit", function(event) {
                const date = new Date();
                const hour = date.getHours();
                // Zero-pad minutes so "9:05" does not render as "9:5".
                const str_time = hour + ":" + String(date.getMinutes()).padStart(2, "0");
                var rawText = $("#text").val();

                var userHtml = '<div class="d-flex justify-content-end mb-4"><div class="msg_cotainer_send">' + escapeHtml(rawText) + '<span class="msg_time_send">'+ str_time + '</span></div><div class="img_cont_msg"><img src="https://i.ibb.co/d5b84Xw/Untitled-design.png" class="rounded-circle user_img_msg"></div></div>';

                $("#text").val("");
                $("#messageFormeight").append(userHtml);

                $.ajax({
                    data: {
                        msg: rawText,
                    },
                    type: "POST",
                    url: "/get",
                }).done(function(data) {
                    var botHtml = '<div class="d-flex justify-content-start mb-4"><div class="img_cont_msg"><img src="https://stickerpacks.ru/wp-content/uploads/2023/04/nabor-stikerov-teorija-bolshogo-vzryva-5-dlja-telegram-3.webp" class="rounded-circle user_img_msg"></div><div class="msg_cotainer">' + escapeHtml(data) + '<span class="msg_time">' + str_time + '</span></div></div>';
                    $("#messageFormeight").append($.parseHTML(botHtml));
                });
                event.preventDefault();
            });
        });
    </script>

</body>
</html>
utils.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+ from scipy import sparse
5
+ import pickle
6
+
7
+
8
def scripts_rework(path, character):
    """FOR GENERATIVE MODEL TRAINING.

    Splits raw scripts into (answer, question, context) rows for the given
    ``character``, augments each row with every context suffix, and writes
    the result to ``data/scripts_reworked.pkl``.

    Fixes over the previous version:
    - scene numbering used chained assignment (``df.iloc[i][col] = v``),
      which writes to a temporary copy and was silently lost;
    - ``df.iloc[index - i]`` was evaluated before the ``index - i >= 0``
      guard, wrapping to the end of the frame (or raising IndexError);
    - rows are accumulated in a list instead of quadratic ``pd.concat``.
    """
    df = pd.read_csv(path)

    # Tag every row with a running scene number ("Scene" rows open a new one).
    count = 0
    scene_counts = []
    for index, row in df.iterrows():
        if index != 0 and row["person_scene"] == "Scene":
            count += 1
        scene_counts.append(count)
    df["scene_count"] = scene_counts

    df = df.dropna().reset_index()

    # Collect training rows for the target character.
    rows = []
    for index, row in df.iterrows():
        if row["person_scene"] != character:
            continue
        if df.iloc[index - 1]["person_scene"] != "Scene":
            # Up to 4 preceding utterances, stopping at a scene boundary
            # or the start of the frame.
            context = []
            for i in reversed(range(2, 6)):
                if index - i >= 0 and df.iloc[index - i]["person_scene"] != "Scene":
                    context.append(df.iloc[index - i]["dialogue"])
                else:
                    break

            question = df.iloc[index - 1]["dialogue"]
            # Augmentation: one row per context suffix, plus an empty-context row.
            for j in range(len(context)):
                rows.append(
                    {
                        "answer": row["dialogue"],
                        "question": question,
                        "context": context[j:],
                    }
                )
            rows.append(
                {"answer": row["dialogue"], "question": question, "context": []}
            )
        else:
            # Character opens the scene: no question available.
            rows.append({"answer": row["dialogue"], "question": "", "context": []})

    # Drop scene-opening lines, flatten context to one string, save.
    scripts = pd.DataFrame(rows, columns=["answer", "question", "context"])
    scripts = scripts[scripts["question"] != ""]
    scripts["context"] = scripts["context"].apply(lambda x: "".join(x))
    scripts = scripts.reset_index(drop=True)
    scripts.to_pickle("data/scripts_reworked.pkl")
69
+
70
+
71
+ # ===================================================
72
def scripts_rework_ranking(path, character):
    """FOR RAG RETRIEVAL.

    Splits raw scripts into (answer, question, context) rows for the given
    ``character`` and writes the result to ``data/scripts.pkl``. Unlike
    :func:`scripts_rework`, context stays a list and is not augmented.

    Fixes over the previous version:
    - scene numbering used chained assignment (``df.iloc[i][col] = v``),
      which writes to a temporary copy and was silently lost;
    - ``df.iloc[index - i]`` was evaluated before the ``index - i >= 0``
      guard, wrapping to the end of the frame (or raising IndexError);
    - rows are accumulated in a list instead of quadratic ``pd.concat``.
    """
    df = pd.read_csv(path)

    # Tag every row with a running scene number ("Scene" rows open a new one).
    count = 0
    scene_counts = []
    for index, row in df.iterrows():
        if index != 0 and row["person_scene"] == "Scene":
            count += 1
        scene_counts.append(count)
    df["scene_count"] = scene_counts

    df = df.dropna().reset_index()

    # Collect retrieval rows for the target character.
    rows = []
    for index, row in df.iterrows():
        if row["person_scene"] != character:
            continue
        if df.iloc[index - 1]["person_scene"] != "Scene":
            # Up to 3 preceding utterances, stopping at a scene boundary
            # or the start of the frame.
            context = []
            for i in reversed(range(2, 5)):
                if index - i >= 0 and df.iloc[index - i]["person_scene"] != "Scene":
                    context.append(df.iloc[index - i]["dialogue"])
                else:
                    break
            rows.append(
                {
                    "answer": row["dialogue"],
                    "question": df.iloc[index - 1]["dialogue"],
                    "context": context,
                }
            )
        else:
            # Character opens the scene: no question available.
            rows.append({"answer": row["dialogue"], "question": "", "context": []})

    # Drop scene-opening lines and save.
    scripts = pd.DataFrame(rows, columns=["answer", "question", "context"])
    scripts = scripts[scripts["question"] != ""]
    scripts = scripts.reset_index(drop=True)
    scripts.to_pickle("data/scripts.pkl")
123
+
124
+
125
+ # ===================================================
126
def encode(texts, model, contexts=None, do_norm=True):
    """Encode question text(s) and context for cosine-similarity search.

    A list of context utterances is flattened to one string before
    encoding; any other value is passed to ``model.encode`` as-is.
    Returns an ndarray with the context vector concatenated before the
    question vector along the last axis.

    NOTE(review): ``do_norm`` is accepted but unused — kept for interface
    compatibility; confirm whether normalization was ever intended.
    """
    question_vectors = model.encode(texts)
    # isinstance (not `type(...) is list`) also accepts list subclasses.
    if isinstance(contexts, list):
        context_vectors = model.encode("".join(contexts))
    else:
        context_vectors = model.encode(contexts)

    return np.concatenate(
        [
            np.asarray(context_vectors),
            np.asarray(question_vectors),
        ],
        axis=-1,
    )
142
+
143
+
144
def encode_rag(texts, model, contexts=None, do_norm=True):
    """Encode question text(s) and joined context for RAG retrieval.

    ``contexts`` may be any iterable of strings (e.g. the conversation
    deque); it is flattened to one string. Returns an ndarray with the
    context vector concatenated before the question vector.

    NOTE(review): ``do_norm`` is accepted but unused — kept for interface
    compatibility; confirm whether normalization was ever intended.
    """
    question_vectors = model.encode(texts)
    # Guard the declared default: "".join(None) would raise TypeError.
    context_vectors = model.encode("".join(contexts) if contexts is not None else "")

    return np.concatenate(
        [
            np.asarray(context_vectors),
            np.asarray(question_vectors),
        ],
        axis=-1,
    )
157
+
158
+
159
+ # ===================================================
160
def encode_df_save(model):
    """FOR RAG RETRIEVAL DATABASE.

    Vectorizes the reworked scripts and dumps the vectors to a pickle
    file that serves as the retrieval base for the ranking script.
    """
    scripts_reopened = pd.read_pickle("data/scripts.pkl")
    vect_data = []
    for _, script_row in scripts_reopened.iterrows():
        raw_context = script_row["context"]
        # A list of utterances is flattened to one string up front; any
        # other value (already a string) is passed through unchanged.
        if type(raw_context) is list:
            raw_context = "".join(raw_context)
        vect_data.append(
            encode(
                texts=script_row["question"],
                model=model,
                contexts=raw_context,
            )
        )
    with open("data/scripts_vectors.pkl", "wb") as f:
        pickle.dump(vect_data, f)
184
+
185
+
186
+ # ===================================================
187
# ===================================================
def cosine_sim(answer_true_vectros, answer_generated_vectors) -> float:
    """FOR MODEL EVALUATION!!!!
    Returns the cosine similarity between the true-answer vector and the
    generated-answer vector as a single scalar (the previous ``-> list``
    annotation was wrong: ``similarity[0]`` is one number)."""

    data_emb = sparse.csr_matrix(answer_true_vectros)
    query_emb = sparse.csr_matrix(answer_generated_vectors)
    # Row-by-row cosine similarity, flattened to a 1-D array.
    similarity = cosine_similarity(query_emb, data_emb).flatten()
    return similarity[0]
195
+
196
+
197
+ # ===================================================
198
# ===================================================
def cosine_sim_rag(data_vectors, query_vectors) -> list:
    """FOR RAG RETRIEVAL RANKS!!!
    returns list of tuples with similarity score and
    script index in initial dataframe, sorted best-first"""

    data_emb = sparse.csr_matrix(data_vectors)
    query_emb = sparse.csr_matrix(query_vectors)
    similarity = cosine_similarity(query_emb, data_emb).flatten()
    # np.argwhere on a 1-D array yields [[i], ...] for *non-zero* entries.
    # NOTE(review): exactly-zero similarities are dropped here, which would
    # misalign the zip below if any occur — confirm scores are never 0.
    ind = np.argwhere(similarity)
    # Pair each score with its [index] and sort descending by score.
    match = sorted(zip(similarity, ind.tolist()), reverse=True)

    return match
210
+
211
+
212
+ # ===================================================
213
def generate_response(
    model,
    tokenizer,
    question,
    context,
    top_p,
    temperature,
    rag_answer="",
):
    """Generate one chat reply with a seq2seq model.

    ``context`` is an iterable of previous utterances; ``rag_answer`` is an
    optional retrieved scripted answer prepended to the context. Sampling
    is controlled by ``top_p`` and ``temperature``. Returns the decoded
    reply string.
    """

    # Prompt layout: "context:<rag answer><history></s>question: <question>".
    combined = (
        "context:" + rag_answer +
        "".join(context) + "</s>" +
        "question: " + question
    )
    input_ids = tokenizer.encode(combined, return_tensors="pt")
    sample_output = model.generate(
        input_ids,
        do_sample=True,
        max_length=1000,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=2.0,  # discourage verbatim repetition
        top_k=50,
        no_repeat_ngram_size=4,
        # early_stopping=True,
        # min_length=10,
    )

    # [1:] drops the first generated token — presumably the decoder start
    # token for this T5-style model; TODO confirm for this tokenizer.
    out = tokenizer.decode(sample_output[0][1:], skip_special_tokens=True)
    # Anything after a literal "</s>" marker in the text is discarded.
    if "</s>" in out:
        out = out[: out.find("</s>")].strip()

    return out
247
+
248
+
249
+ # ===================================================
250
def top_candidates(score_lst_sorted, initial_data, top=1):
    """Return the best ``top`` similarity scores and their indices.

    ``score_lst_sorted`` is the best-first output of ``cosine_sim_rag``:
    tuples of (score, [dataframe_index]). Returns two parallel lists:
    scores and indices.

    NOTE(review): ``initial_data`` is unused; kept for interface
    compatibility with existing callers.
    """
    # Slice first, then unpack — avoids building two full-length lists.
    best = score_lst_sorted[:top]
    scores = [score for score, _ in best]
    candidates_indexes = [indices[0] for _, indices in best]
    return scores, candidates_indexes