Desalegnn committed (verified)
Commit 03960ab · 1 Parent(s): c0bc124

Update app.py

Files changed (1)
  1. app.py +12 -8
app.py CHANGED
@@ -10,10 +10,12 @@ from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
 from sklearn.model_selection import train_test_split
 from sklearn.metrics import classification_report
 from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
+
 # Load BERT model and tokenizer via HuggingFace Transformers
 bert = XLMRobertaModel.from_pretrained('castorini/afriberta_large')
 tokenizer = XLMRobertaTokenizer.from_pretrained('castorini/afriberta_large')
-# Define the model architecture
+
+# Define the model architecture for three classes
 class BERT_Arch(nn.Module):
     def __init__(self, bert):
         super(BERT_Arch, self).__init__()
@@ -21,7 +23,7 @@ class BERT_Arch(nn.Module):
         self.dropout = nn.Dropout(0.1)  # Dropout layer
         self.relu = nn.ReLU()  # ReLU activation function
         self.fc1 = nn.Linear(768, 512)  # Dense layer 1
-        self.fc2 = nn.Linear(512, 2)  # Dense layer 2 (Output layer)
+        self.fc2 = nn.Linear(512, 3)  # Dense layer 2 (Output layer for 3 classes)
         self.softmax = nn.LogSoftmax(dim=1)  # Softmax activation function

     def forward(self, sent_id, mask):  # Define the forward pass
@@ -39,19 +41,21 @@ fake_news_model_path = "Hate_Speech_model.pt"
 fake_news_model = torch.load(fake_news_model_path, map_location=torch.device('cpu'))
 fake_news_model.eval()

-# Function to detect fake news
+# Mapping labels to classes
+LABELS = {0: "Free", 1: "Hate", 2: "Offensive"}
+
+# Function to detect hate speech
 def detect_fake_news(text):
     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
     with torch.no_grad():
         outputs = fake_news_model(inputs['input_ids'], inputs['attention_mask'])
     label = torch.argmax(outputs, dim=1).item()
-    fake_news_result = "Hate" if label == 1 else "Real"
-    return fake_news_result
+    return LABELS[label]

 # Function to handle post logic
 def post_text(text, fake_news_result):
-    if fake_news_result == "Hate":
-        return "Your message contains Hate Speech and cannot be posted.", ""
+    if fake_news_result in ["Hate", "Offensive"]:
+        return f"Your message contains {fake_news_result} Speech and cannot be posted.", ""
     else:
         return "The text is safe to post.", text

@@ -64,7 +68,7 @@ with interface:
     with gr.Row():
         detect_fake_button = gr.Button("Detect Hate Speech")
     with gr.Row():
-        fake_news_result_box = gr.Textbox(label="Hate speech Detection Result", interactive=False)
+        fake_news_result_box = gr.Textbox(label="Hate Speech Detection Result", interactive=False)
     with gr.Row():
         post_button = gr.Button("Post Text")
     with gr.Row():
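
For reference, a minimal usage sketch of the updated three-class flow. It assumes app.py has already loaded the model and defined detect_fake_news and post_text as in the diff above; the sample string and variable names below are illustrative only and not part of the commit.

# Illustrative only: exercises the functions updated in this commit.
sample = "Example user post"                 # hypothetical input text
result = detect_fake_news(sample)            # one of "Free", "Hate", "Offensive"
status, posted = post_text(sample, result)   # "Hate" and "Offensive" texts are blocked
print(result, status, repr(posted))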