JustKiddo commited on
Commit
7f79d8b
·
verified ·
1 Parent(s): dc14176

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -52
app.py CHANGED
@@ -4,54 +4,79 @@ import torch
4
  from transformers import AutoTokenizer, AutoModel
5
  import numpy as np
6
  from sklearn.metrics.pairwise import cosine_similarity
7
-
8
- # Get the port from Heroku environment, default to 8501 for local development
9
- PORT = int(os.environ.get('PORT', 8501))
10
-
11
- class LazyLoadModel:
12
- def __init__(self, model_name='intfloat/multilingual-e5-small'):
13
- self.model_name = model_name
14
- self._tokenizer = None
15
- self._model = None
16
-
17
- @property
18
- def tokenizer(self):
19
- if self._tokenizer is None:
20
- print("Loading tokenizer...")
21
- self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
22
- return self._tokenizer
23
-
24
- @property
25
- def model(self):
26
- if self._model is None:
27
- print("Loading model...")
28
- # Use float16 to reduce memory and potentially speed up loading
29
- self._model = AutoModel.from_pretrained(self.model_name, torch_dtype=torch.float16)
30
- return self._model
31
 
32
  class VietnameseChatbot:
33
- def __init__(self):
34
  """
35
- Initialize the Vietnamese chatbot with lazy-loaded model
36
  """
37
- self.model_loader = LazyLoadModel()
 
 
 
 
 
38
 
39
- # Very minimal conversation data to reduce startup time
40
- self.conversation_data = [
41
- {"query": "Xin chào", "response": "Chào bạn!"},
42
- {"query": "Bạn ai?", "response": "Tôi là trợ lý AI."},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  ]
44
-
 
 
 
 
 
 
 
 
 
 
 
45
  def embed_text(self, text):
46
  """
47
  Generate embeddings for input text
48
  """
49
  try:
50
  # Tokenize and generate embeddings
51
- inputs = self.model_loader.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
52
 
53
  with torch.no_grad():
54
- model_output = self.model_loader.model(**inputs)
55
 
56
  # Mean pooling
57
  embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
@@ -59,7 +84,7 @@ class VietnameseChatbot:
59
  except Exception as e:
60
  print(f"Embedding error: {e}")
61
  return None
62
-
63
  def mean_pooling(self, model_output, attention_mask):
64
  """
65
  Perform mean pooling on model output
@@ -67,7 +92,7 @@ class VietnameseChatbot:
67
  token_embeddings = model_output[0]
68
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
69
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
70
-
71
  def get_response(self, user_query):
72
  """
73
  Find the most similar response from conversation data
@@ -77,15 +102,10 @@ class VietnameseChatbot:
77
  query_embedding = self.embed_text(user_query)
78
 
79
  if query_embedding is None:
80
- return "Xin lỗi, đã có lỗi xảy ra."
81
-
82
- # Embed conversation data
83
- conversation_embeddings = np.array([
84
- self.embed_text(item['query'])[0] for item in self.conversation_data
85
- ])
86
 
87
  # Calculate cosine similarities
88
- similarities = cosine_similarity(query_embedding, conversation_embeddings)[0]
89
 
90
  # Find most similar response
91
  best_match_index = np.argmax(similarities)
@@ -94,26 +114,33 @@ class VietnameseChatbot:
94
  if similarities[best_match_index] > 0.5:
95
  return self.conversation_data[best_match_index]['response']
96
 
97
- return "Xin lỗi, tôi không hiểu câu hỏi của bạn."
98
  except Exception as e:
99
  print(f"Response generation error: {e}")
100
  return "Đã xảy ra lỗi. Xin vui lòng thử lại."
101
 
102
  def main():
103
- # Server configuration to use Heroku-assigned port
104
- if 'PORT' in os.environ:
105
- #st.set_option('server.port', PORT)
106
- print(f"Server starting on port {PORT}")
107
 
108
  st.title("🤖 Trợ Lý AI Tiếng Việt")
 
109
 
110
- # Initialize chatbot
111
  chatbot = VietnameseChatbot()
112
 
113
  # Chat history in session state
114
  if 'messages' not in st.session_state:
115
  st.session_state.messages = []
116
 
 
 
 
 
 
 
117
  # Display chat messages
118
  for message in st.session_state.messages:
119
  with st.chat_message(message["role"]):
@@ -138,8 +165,5 @@ def main():
138
  # Add assistant message to chat history
139
  st.session_state.messages.append({"role": "assistant", "content": response})
140
 
141
- # Logging for Heroku diagnostics
142
- print("Chatbot application is initializing...")
143
-
144
  if __name__ == "__main__":
145
  main()
 
4
  from transformers import AutoTokenizer, AutoModel
5
  import numpy as np
6
  from sklearn.metrics.pairwise import cosine_similarity
7
+ import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class VietnameseChatbot:
10
+ def __init__(self, model_name='intfloat/multilingual-e5-small'):
11
  """
12
+ Initialize the Vietnamese chatbot with pre-loaded model and conversation data
13
  """
14
+ # Load pre-trained model and tokenizer
15
+ print("Loading tokenizer...")
16
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+
18
+ print("Loading model...")
19
+ self.model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)
20
 
21
+ # Load comprehensive conversation dataset
22
+ self.conversation_data = self._load_conversation_data()
23
+
24
+ # Pre-compute embeddings for faster response generation
25
+ print("Pre-computing conversation embeddings...")
26
+ self.conversation_embeddings = self._precompute_embeddings()
27
+
28
+ def _load_conversation_data(self):
29
+ """
30
+ Load a comprehensive conversation dataset
31
+ """
32
+ return [
33
+ # Greeting conversations
34
+ {"query": "Xin chào", "response": "Chào bạn! Tôi có thể giúp gì cho bạn?"},
35
+ {"query": "Hi", "response": "Xin chào! Tôi là trợ lý AI tiếng Việt."},
36
+ {"query": "Chào buổi sáng", "response": "Chào buổi sáng! Chúc bạn một ngày tốt lành."},
37
+
38
+ # Identity and purpose
39
+ {"query": "Bạn là ai?", "response": "Tôi là trợ lý AI được phát triển để hỗ trợ và trò chuyện bằng tiếng Việt."},
40
+ {"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."},
41
+
42
+ # Small talk
43
+ {"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giúp đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."},
44
+ {"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."},
45
+
46
+ # Weather and time
47
+ {"query": "Thời tiết hôm nay thế nào?", "response": "Xin lỗi, tôi không thể cung cấp thông tin thời tiết trực tiếp. Bạn có thể kiểm tra ứng dụng dự báo thời tiết."},
48
+ {"query": "Bây giờ là mấy giờ?", "response": "Tôi là trợ lý AI, nên không thể xem đồng hồ. Bạn có thể kiểm tra thiết bị của mình."},
49
+
50
+ # Assistance offers
51
+ {"query": "Tôi cần trợ giúp", "response": "Tôi sẵn sàng hỗ trợ bạn. Bạn cần giúp gì?"},
52
+ {"query": "Giúp tôi với cái gì đó", "response": "Vâng, tôi có thể hỗ trợ bạn. Hãy cho tôi biết chi tiết hơn."},
53
+
54
+ # Farewell
55
+ {"query": "Tạm biệt", "response": "Hẹn gặp lại! Chúc bạn một ngày tốt đẹp."},
56
+ {"query": "Bye", "response": "Tạm biệt! Rất vui được trò chuyện với bạn."},
57
  ]
58
+
59
+ def _precompute_embeddings(self):
60
+ """
61
+ Pre-compute embeddings for all conversation queries
62
+ """
63
+ embeddings = []
64
+ for item in self.conversation_data:
65
+ embedding = self.embed_text(item['query'])
66
+ if embedding is not None:
67
+ embeddings.append(embedding[0])
68
+ return np.array(embeddings)
69
+
70
  def embed_text(self, text):
71
  """
72
  Generate embeddings for input text
73
  """
74
  try:
75
  # Tokenize and generate embeddings
76
+ inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
77
 
78
  with torch.no_grad():
79
+ model_output = self.model(**inputs)
80
 
81
  # Mean pooling
82
  embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
 
84
  except Exception as e:
85
  print(f"Embedding error: {e}")
86
  return None
87
+
88
  def mean_pooling(self, model_output, attention_mask):
89
  """
90
  Perform mean pooling on model output
 
92
  token_embeddings = model_output[0]
93
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
94
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
95
+
96
  def get_response(self, user_query):
97
  """
98
  Find the most similar response from conversation data
 
102
  query_embedding = self.embed_text(user_query)
103
 
104
  if query_embedding is None:
105
+ return "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi của bạn."
 
 
 
 
 
106
 
107
  # Calculate cosine similarities
108
+ similarities = cosine_similarity(query_embedding, self.conversation_embeddings)[0]
109
 
110
  # Find most similar response
111
  best_match_index = np.argmax(similarities)
 
114
  if similarities[best_match_index] > 0.5:
115
  return self.conversation_data[best_match_index]['response']
116
 
117
+ return "Xin lỗi, tôi chưa hiểu câu hỏi của bạn. Bạn có thể diễn đạt lại được không?"
118
  except Exception as e:
119
  print(f"Response generation error: {e}")
120
  return "Đã xảy ra lỗi. Xin vui lòng thử lại."
121
 
122
  def main():
123
+ st.set_page_config(
124
+ page_title="Trợ AI Tiếng Việt",
125
+ page_icon="🤖",
126
+ )
127
 
128
  st.title("🤖 Trợ Lý AI Tiếng Việt")
129
+ st.caption("Trò chuyện với trợ lý AI được phát triển bằng mô hình đa ngôn ngữ")
130
 
131
+ # Initialize chatbot (this will pre-load models and embeddings)
132
  chatbot = VietnameseChatbot()
133
 
134
  # Chat history in session state
135
  if 'messages' not in st.session_state:
136
  st.session_state.messages = []
137
 
138
+ # Sidebar for additional information
139
+ with st.sidebar:
140
+ st.header("Về Trợ Lý AI")
141
+ st.write("Đây là một trợ lý AI được phát triển để hỗ trợ trò chuyện bằng tiếng Việt.")
142
+ st.write("Mô hình sử dụng: intfloat/multilingual-e5-small")
143
+
144
  # Display chat messages
145
  for message in st.session_state.messages:
146
  with st.chat_message(message["role"]):
 
165
  # Add assistant message to chat history
166
  st.session_state.messages.append({"role": "assistant", "content": response})
167
 
 
 
 
168
  if __name__ == "__main__":
169
  main()