Bui Trung
committed on
Upload 2 files
- app.py +193 -0
- phobert_fold1.pth +3 -0
app.py
ADDED
@@ -0,0 +1,193 @@
import jmespath
import asyncio
import json
from urllib.parse import urlencode
from typing import List, Dict
from httpx import AsyncClient, Response
from loguru import logger as log
import nest_asyncio
import torch
import torch.nn as nn
from transformers import AutoModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr

client = AsyncClient(
    # enable http2
    http2=True,
    headers={
        "Accept-Language": "en-US,en;q=0.9",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
        "Accept-Encoding": "gzip, deflate, br",
        "content-type": "application/json"
    },
)

def parse_comments(response: Response) -> Dict:
    try:
        data = json.loads(response.text)
    except json.JSONDecodeError:
        log.error(f"Failed to parse JSON response: {response.text}")
        return {"comments": [], "total_comments": 0}

    comments_data = data.get("comments", [])
    total_comments = data.get("total", 0)

    if not comments_data:
        log.warning(f"No comments found in response: {response.text}")
        return {"comments": [], "total_comments": total_comments}

    parsed_comments = []
    for comment in comments_data:
        # keep only the comment text
        result = jmespath.search(
            """{
                text: text
            }""",
            comment
        )
        parsed_comments.append(result)
    return {"comments": parsed_comments, "total_comments": total_comments}

async def scrape_comments(post_id: int, comments_count: int = 20, max_comments: int = None) -> List[Dict]:

    def form_api_url(cursor: int):
        base_url = "https://www.tiktok.com/api/comment/list/?"
        params = {
            "aweme_id": post_id,
            "count": comments_count,
            "cursor": cursor  # the index to start from
        }
        return base_url + urlencode(params)

    log.info(f"Scraping comments from post ID: {post_id}")
    first_page = await client.get(form_api_url(0))
    data = parse_comments(first_page)
    comments_data = data["comments"]
    total_comments = data["total_comments"]

    if not comments_data:
        log.warning(f"No comments found for post ID {post_id}")
        return []
    if max_comments and max_comments < total_comments:
        total_comments = max_comments

    log.info(f"Scraping comments pagination, remaining {total_comments // comments_count - 1} more pages")
    _other_pages = [
        client.get(form_api_url(cursor=cursor))
        for cursor in range(comments_count, total_comments + comments_count, comments_count)
    ]

    for response in asyncio.as_completed(_other_pages):
        response = await response
        new_comments = parse_comments(response)["comments"]
        comments_data.extend(new_comments)

        # If we have reached or exceeded the maximum number of comments to scrape, stop the process
        if max_comments and len(comments_data) >= max_comments:
            comments_data = comments_data[:max_comments]
            break

    log.success(f"Scraped {len(comments_data)} comments from post ID {post_id}")
    return comments_data

class SentimentClassifier(nn.Module):
    def __init__(self, n_classes):
        super(SentimentClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained("vinai/phobert-base")
        self.drop = nn.Dropout(p=0.3)
        self.fc = nn.Linear(self.bert.config.hidden_size, n_classes)
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.normal_(self.fc.bias, 0)

    def forward(self, input_ids, attention_mask):
        last_hidden_state, output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False  # return a tuple; the dropout layer errors out without this
        )

        x = self.drop(output)
        x = self.fc(x)
        return x

def infer(text, tokenizer, max_len=120):
    encoded_review = tokenizer.encode_plus(
        text,
        max_length=max_len,
        truncation=True,
        add_special_tokens=True,
        padding='max_length',
        return_attention_mask=True,
        return_token_type_ids=False,
        return_tensors='pt',
    )

    input_ids = encoded_review['input_ids'].to(device)
    attention_mask = encoded_review['attention_mask'].to(device)

    # no gradients needed at inference time
    with torch.no_grad():
        output = model(input_ids, attention_mask)
    _, y_pred = torch.max(output, dim=1)

    return class_names[y_pred]

async def predict_comments(video_id):
    comments = await scrape_comments(
        post_id=int(video_id),
        max_comments=2000,
        comments_count=20
    )
    predictions = []
    for comment in comments:
        text = comment['text']
        probs = infer(text, tokenizer)
        predictions.append({'comment': text, 'predictions': probs})

    # Compute the percentage of each label
    total_comments = len(predictions)
    label_counts = [0, 0, 0]  # Assuming there are 3 labels
    comment_off = []
    comment_hate = []
    for prediction in predictions:
        probs = prediction['predictions']
        if probs == 'CLEAN':
            label_counts[0] += 1
        elif probs == 'OFFENSIVE':
            label_counts[1] += 1
            comment_off.append(prediction['comment'])
        else:
            label_counts[2] += 1
            comment_hate.append(prediction['comment'])

    # guard against division by zero when no comments were scraped
    if total_comments == 0:
        label_percentages = [0, 0, 0]
    else:
        label_percentages = [count / total_comments * 100 for count in label_counts]
    results = {
        'total_comments': total_comments,
        'label_percentages': {
            'CLEAN': label_percentages[0],
            'OFFENSIVE': label_percentages[1],
            'HATE': label_percentages[2],
            'CMT OFFENSIVE': comment_off,
            'CMT HATE': comment_hate,
        }
    }

    return results

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = SentimentClassifier(n_classes=3)
model.to(device)
# map_location keeps the checkpoint loadable on CPU-only machines
model.load_state_dict(torch.load('phobert_fold1.pth', map_location=device))
model.eval()  # disable dropout during inference

tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

class_names = ['CLEAN', 'OFFENSIVE', 'HATE']


iface = gr.Interface(
    fn=predict_comments,
    inputs="text",
    outputs="json"
)

iface.launch()
phobert_fold1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:275266c8ce8e44222f98cccd875b1e38c8fd677b89a90a147dcc3fdd4de55dfd
size 540085102