Update app.py
Browse files
app.py
CHANGED
@@ -10,10 +10,12 @@ from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
|
10 |
from sklearn.model_selection import train_test_split
|
11 |
from sklearn.metrics import classification_report
|
12 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
|
|
|
13 |
# Load the AfriBERTa encoder (XLM-RoBERTa architecture) and its tokenizer via HuggingFace Transformers
|
14 |
bert = XLMRobertaModel.from_pretrained('castorini/afriberta_large')
|
15 |
tokenizer = XLMRobertaTokenizer.from_pretrained('castorini/afriberta_large')
|
16 |
-
|
|
|
17 |
class BERT_Arch(nn.Module):
|
18 |
def __init__(self, bert):
|
19 |
super(BERT_Arch, self).__init__()
|
@@ -21,7 +23,7 @@ class BERT_Arch(nn.Module):
|
|
21 |
self.dropout = nn.Dropout(0.1) # Dropout layer
|
22 |
self.relu = nn.ReLU() # ReLU activation function
|
23 |
self.fc1 = nn.Linear(768, 512) # Dense layer 1
|
24 |
-
self.fc2 = nn.Linear(512,
|
25 |
self.softmax = nn.LogSoftmax(dim=1) # Softmax activation function
|
26 |
|
27 |
def forward(self, sent_id, mask): # Define the forward pass
|
@@ -39,19 +41,21 @@ fake_news_model_path = "Hate_Speech_model.pt"
|
|
39 |
fake_news_model = torch.load(fake_news_model_path, map_location=torch.device('cpu'))
|
40 |
fake_news_model.eval()
|
41 |
|
42 |
-
#
|
|
|
|
|
|
|
43 |
def detect_fake_news(text):
|
44 |
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
45 |
with torch.no_grad():
|
46 |
outputs = fake_news_model(inputs['input_ids'], inputs['attention_mask'])
|
47 |
label = torch.argmax(outputs, dim=1).item()
|
48 |
-
|
49 |
-
return fake_news_result
|
50 |
|
51 |
# Function to handle post logic
|
52 |
def post_text(text, fake_news_result):
|
53 |
-
if fake_news_result
|
54 |
-
return "Your message contains
|
55 |
else:
|
56 |
return "The text is safe to post.", text
|
57 |
|
@@ -64,7 +68,7 @@ with interface:
|
|
64 |
with gr.Row():
|
65 |
detect_fake_button = gr.Button("Detect Hate Speech")
|
66 |
with gr.Row():
|
67 |
-
fake_news_result_box = gr.Textbox(label="Hate
|
68 |
with gr.Row():
|
69 |
post_button = gr.Button("Post Text")
|
70 |
with gr.Row():
|
|
|
10 |
from sklearn.model_selection import train_test_split
|
11 |
from sklearn.metrics import classification_report
|
12 |
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
|
13 |
+
|
14 |
# Load the AfriBERTa encoder (XLM-RoBERTa architecture) and its tokenizer via HuggingFace Transformers
|
15 |
bert = XLMRobertaModel.from_pretrained('castorini/afriberta_large')
|
16 |
tokenizer = XLMRobertaTokenizer.from_pretrained('castorini/afriberta_large')
|
17 |
+
|
18 |
+
# Define the model architecture for three classes
|
19 |
class BERT_Arch(nn.Module):
|
20 |
def __init__(self, bert):
|
21 |
super(BERT_Arch, self).__init__()
|
|
|
23 |
self.dropout = nn.Dropout(0.1) # Dropout layer
|
24 |
self.relu = nn.ReLU() # ReLU activation function
|
25 |
self.fc1 = nn.Linear(768, 512) # Dense layer 1
|
26 |
+
self.fc2 = nn.Linear(512, 3) # Dense layer 2 (Output layer for 3 classes)
|
27 |
self.softmax = nn.LogSoftmax(dim=1) # Softmax activation function
|
28 |
|
29 |
def forward(self, sent_id, mask): # Define the forward pass
|
|
|
41 |
fake_news_model = torch.load(fake_news_model_path, map_location=torch.device('cpu'))
|
42 |
fake_news_model.eval()
|
43 |
|
44 |
+
# Mapping labels to classes
|
45 |
+
LABELS = {0: "Free", 1: "Hate", 2: "Offensive"}
|
46 |
+
|
47 |
+
# Function to detect hate speech
|
48 |
def detect_fake_news(text: str) -> str:
    """Classify ``text`` and return its label name from ``LABELS``.

    The text is tokenized (truncated to at most 128 tokens), passed through
    ``fake_news_model`` with gradient tracking disabled, and the index of the
    highest score along dimension 1 is looked up in ``LABELS``.

    Despite the legacy ``fake_news`` naming, the labels produced here are the
    hate-speech classes defined in ``LABELS``.

    Args:
        text: The message to classify. Presumably a single string, since
            ``.item()`` below expects exactly one prediction -- TODO confirm
            that no caller passes a batch of texts.

    Returns:
        The ``LABELS`` entry for the predicted class index.

    Raises:
        KeyError: If the predicted index has no entry in ``LABELS``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = fake_news_model(inputs['input_ids'], inputs['attention_mask'])
    # Index of the highest-scoring class, converted to a plain Python int.
    label = torch.argmax(outputs, dim=1).item()
    return LABELS[label]
|
|
|
54 |
|
55 |
# Function to handle post logic
|
56 |
def post_text(text: str, fake_news_result: str):
    """Decide whether ``text`` may be posted, given its classification.

    Args:
        text: The message the user wants to post.
        fake_news_result: The classification label previously computed for
            ``text``. Only the exact strings ``"Hate"`` and ``"Offensive"``
            block the post; any other value is treated as safe.

    Returns:
        A two-element tuple ``(status_message, text_to_post)``. When the
        message is blocked, ``text_to_post`` is an empty string; otherwise it
        is ``text`` unchanged.
    """
    # Guard clause: refuse anything classified as hateful or offensive.
    if fake_news_result in ("Hate", "Offensive"):
        return f"Your message contains {fake_news_result} Speech and cannot be posted.", ""
    return "The text is safe to post.", text
|
61 |
|
|
|
68 |
with gr.Row():
|
69 |
detect_fake_button = gr.Button("Detect Hate Speech")
|
70 |
with gr.Row():
|
71 |
+
fake_news_result_box = gr.Textbox(label="Hate Speech Detection Result", interactive=False)
|
72 |
with gr.Row():
|
73 |
post_button = gr.Button("Post Text")
|
74 |
with gr.Row():
|