Spaces:
Sleeping
Sleeping
DataRaptor
commited on
Commit
·
38f2246
1
Parent(s):
3321385
Upload 7 files
Browse files- app.py +180 -1
- dataset.py +61 -0
- image/wandb.jpg +0 -0
- model.py +115 -0
- model_ind2cat.csv +5 -0
- requirements.txt +11 -0
- utils.py +20 -0
app.py
CHANGED
@@ -1,6 +1,185 @@
|
|
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
|
|
1 |
+
import pandas as pd
|
2 |
import streamlit as st
|
3 |
+
from dataset import BanglaHSDataset, get_class
|
4 |
+
import torch
|
5 |
+
from model import HSLanguageModel
|
6 |
+
import torch.nn as nn
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
|
10 |
+
@st.cache_resource
|
11 |
+
def load_model():
|
12 |
+
model = HSLanguageModel(
|
13 |
+
backbone = 'sagorsarker/bangla-bert-base',
|
14 |
+
target_size = 4,
|
15 |
+
head_dropout = 0,
|
16 |
+
reinit_nlayers = 0,
|
17 |
+
freeze_nlayers = 0,
|
18 |
+
reinit_head = True,
|
19 |
+
grad_checkpointing = False,
|
20 |
+
)
|
21 |
+
model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))
|
22 |
+
model.eval()
|
23 |
+
ds = BanglaHSDataset(model.tokenizer, 256)
|
24 |
+
return model, ds
|
25 |
|
26 |
|
27 |
+
def infer(text):
|
28 |
+
model, ds = load_model()
|
29 |
+
model.eval()
|
30 |
+
text = ds[text][0]
|
31 |
+
with torch.no_grad():
|
32 |
+
out = model(text)
|
33 |
+
out = nn.Softmax()(out).squeeze()
|
34 |
+
return out
|
35 |
+
|
36 |
+
@st.cache_data
|
37 |
+
def get_image():
|
38 |
+
image = Image.open('./image/wandb.jpg')
|
39 |
+
return image
|
40 |
+
|
41 |
+
|
42 |
+
examples = [
|
43 |
+
"আলোচনায় যুক্ত হতে গেলে নিজের ধর্ম সম্পর্কে জানতে হয় অন্যথায় বোকা ধর্ম আরো বোকা হয়ে যায়",
|
44 |
+
"ইন্ডিয়ায় মালাউনরা কুওার সাথে বিবাহ করছে গরুর মূত খায়",
|
45 |
+
"দাতের বাগান ভাংগা দরকার শালি ভিকারি ভিকারি হয় ফালতু",
|
46 |
+
"মনটায় চায়তাছে তরে পাবনার পাগলা গারদে নিয়ে চুদাই",
|
47 |
+
"আমি হাতে নাতে ধরেছি বা নিজের চোখে দেখেছি আমার বউ আর আমার শালার পরকীয়া",
|
48 |
+
"ভারতের ত্রকটা প্রদেশের সমান পাকিশ্তান ভাই কে কাকে শামলায় পাকিশ্তান বাংলাদেশের সাথেইতো পাড়বেনা ৷",
|
49 |
+
"ভারত পাকিস্তানীর খবর নাই বাংলাদেশীর ঘুম নাই।",
|
50 |
+
"যেভাবে মিডিয়ার মাধ্যমে রানু মন্ডল ভাইরাল হয়েছিল সের মিডিয়ায় আবার আনন্দলকে রেলস্টেশনে ভিখারীর কাতারে বসিয়ে দিতে পারেন অন্য বিকার এদের মাথায় উকুন বেশি নয় কারণ রানুস মন্ডল জানেন রানু মন্ডল এর মাথায় উকুন বেশি তাই উনার মাথার উকুন অন্যদের মাথায় উঠবে বলে তাই এর রানু মন্ডল এদের কাছ থেকে দূরে থাকবেন",
|
51 |
+
"ভাই নোয়াখালীর এক মহিলাকে কিছুক্ষণ আগে ছাএলীগ ধর্ষণ করছে,ঐ ভিডিও টা আছে আমার কাছে",
|
52 |
+
"এসব আবাল গুলা আমাদের সুন্দর বিনোদন থেকে দূরে রাখে।",
|
53 |
+
"সব মেয়ে বারো ভাতারি হয় না ভাই, কিছু কিছু মেয়ে ১০০ ভাতারিও হয়",
|
54 |
+
"আওয়ামী লীগের রাজনীতি শেষ তাড়াতাড়ি সরে পড়ুন ছাত্র লীগের ছাত্র দের বলছি ।",
|
55 |
+
"এমন অ প্রিয় সত্য ইতিহাস তু লে ধরার জন্য অাপনা কে হয় রাজাকার নয় তো ভারত বি রোধী বা পা কিস্থানপ ন্থি উপা ধি বহন কর তে হ তে পা রে বলে ছি লেন বাঙ্গালী দের কেউ দাবায়া রাখ তে পা রে নি পার বেনা অাপনা দের মত অদম্য সাহসী দেরও কেউ দাবায়া রাখ তে পার বেনা ইনশাঅাল্লাহ্ ।",
|
56 |
+
"উনি কোন দেশের কথা বলচে বাংলাদেশ নাকি ভারত",
|
57 |
+
"তোদের কোনো ধর্ম আছে আকাটা নোমো একজনে ধর্ম লিখে খেছে তা তোরা পালন কর তোদের ধর্মে ভগবান হলো সাপ, বানর, হাতি, গরু, ধন, ছামা ইত্যাদি এখন বল তোদের ধর্ম কি আর মুসলমানের ধর্ম কি নিজে একটু ভাবতোতাই না যেনে কথা বলবি না",
|
58 |
+
"বাংলাদেশ ও পাকিস্তান দখল করে অখন্ড ভারত হবে, দখল পরে আগে বগলের বাল কাটতে শিখো",
|
59 |
+
"শ্রী কৃষ্ণ মানুষের গর্বে জন্মেছে ۔একজন জন্ম নেয়া মানুষ কোনোদিন সৃষ্টি কর্তা হতে পারে না",
|
60 |
+
"ভারত পাকিস্তান সব দালালেরা চলে যান",
|
61 |
+
"হামলা ভারত নিয়ন্ত্রিত কাশ্মীরে সেনা ছাউনিতে হামলা বাংলাদেশে চলে ভারত পাকিস্তান নিয়ে কথার মামলা।",
|
62 |
+
"১৯৭১ এ ভারত সাহায্য না করলে আজকে এই বাংলা ভাষায় কথা বলতে পারতি না, ফাকিস্তানের পা চাটতি।",
|
63 |
+
"নো পোরোবলেম ভারতের বিতর এর অংশ আছে তাই তারা অবৈধ না",
|
64 |
+
"বি এম পি কে তো জনগন চায়না তা হলে নিরপক্ষ সরকারের অদিনে নিরবাচন দেয়া হোক দেখি আওয়ামীলীগ শতকরা কত ভোট পায় শতকরা ১০ ভোট পাবেনা আওমীলীগ",
|
65 |
+
"জামাইষষ্ঠীর নেমন্তন্ন না পাওয়ায় উঠোনে শুকোতে দেওয়া শ^শুরের লূঙ্গি চুরি করে পলাতক জামাই",
|
66 |
+
"ভারতে একটি মুসলমানকে থাকতে দেওয়া হবে না ।",
|
67 |
+
"আপনাদের মুসুলমান গুলি শুধু নাম ধারি মুসুলমান এদের মধ্যে প্রকৃত ধর্ম শিক্ষা নেই",
|
68 |
+
"যম যমী কি করেছিল?বোন একটু পড়াশোনা করুনতর্কে জিততে নাআপনার পূর্বপুরুষ এর বেলায় এক নিয়ম এখন একরকম সেটা সম্ভব নাআপনার মতো অসুস্থ দ্বিচারী মানুষ ই বলে",
|
69 |
+
"আল্লার কসম লাগে আপনে যদি আল্লহকে ভয় পেয়ে থাকেন তাহলে,প্রত্যেক মসজিদর ইমাম ও মোয়াজ্জেমের বেতন সরকারি করে দিন।",
|
70 |
+
"এমন মানুষ থাকার চেয়ে না থাকা ভালো",
|
71 |
+
"আপনি ঐসব ফালতু মনগড়া ধর্মকাহিনী না পড়ে একটু লেখাপড়া করুন আর নিজেকে মহাজ্ঞানী ভাবা বন্ধ করুন, যদিও আপনার অবস্থাটা বুঝি এইসব গল্প কেনো মন দিয়ে পড়েন",
|
72 |
+
"এই পরিবারটি দিরাই সুনামগঞ্জ এর ভাতি এলাকায় পরেছে। এই পরিবারের যিনি মুক্তিযোদ্ধা ছিলেন উনি বছর হল মারা গেছেন। উনি দিরাই ইউনিয়নের বছরের সাবেক চ্যায়ারমেন ছিলেন এবং সাবেক মুক্তিযোদ্ধার কমান্ডার ছিলেন। কিন্তু উনি মারা যাওয়ার পর থেকে এই আমানুবিক অত্তাচার শুরু করে কিচু রাজাকার আলবদররা।",
|
73 |
+
"মুসলিম রা হল অসুর আমাদের দেব দেবীর কে গালাগালি করছে।",
|
74 |
+
"একটা জিনিস বুঝলাম না,চোর,গুন্ডা,সন্ত্রাস,ধর্ষক,চাদাবাজ এরা আবার হাজি হয় কিভাবে?",
|
75 |
+
"আওয়ামীলীগ চেতনাবাজ ব্যবসায়ীরা কোথায়। বাংলাদেশের দেশপ্রেমিক মুক্তিযোদ্ধারা মরলে আওয়ামীলীগ এর কি? আওয়ামীলীগ এর চেতনা শুধু আওয়ামীলীগ মুক্তিযোদ্ধা হিন্দু মুক্তিযোদ্ধা ভারতীয় মুক্তিযোদ্ধা দের জন্য চেতনা খাডা হয়ে যায়। যেমন আওয়ামীলীগ রাজাকার হাসিনার বেয়াই চেতনার ঠেলায় এখন ভুয়া মুক্তিযোদ্ধা হয়েগেছে।",
|
76 |
+
"সাবধানে বোনেরা! ১৪ ফেব্রুয়ারি ভালোবাসা দিবসে অতিরিক্ত ভালোবাসা খেয়ে বদহজম হয়ে বমিবমি ভাব যেন না হয়",
|
77 |
+
"খালেদা জিয়ার পিছনে নয়াদিগন্তর সাংবাদিক ছাড়া হয়ত আর কাউকে পাওয়া যাবে না",
|
78 |
+
"মালাউন হিসেবে নয় একজন মানুষ হিসেবে বলছি বাংলাদেশ সরকারের একটু চাপ দেওয়া দরকার ছিল",
|
79 |
+
"এইটা কী তরবাপের দেশ নাকী জামনেচায় তাকরবি তদেরমত মালাউনদের কুতাদিয়া চুদা দরকার",
|
80 |
+
"এই মহিলা যথেষ্ট বেয়াদব এবং মূর্খ্য ও বটে।",
|
81 |
+
"নব্বই দশকের বাংলা ছবি, ফুলশয্যা রাত, লাইট অফ , কপালে চুমু, পরের দিন সকাল",
|
82 |
+
"বাঁশ বাগানের মালিক কি চেটের বাল",
|
83 |
+
"শিরক হলো সবচেয়ে বড় পাপ.....। সাকিব আল হাসান কেন সেটা করতে সাহায্য করলো??",
|
84 |
+
"তর জন্ম কোথায় হইছে পতিতালয়ে নাকি বস্তি পাড়ায়। জারজের বাচ্চা তর বাবা কয়টা তর মায়ের কাছে উত্তরটা জানবি। তোকে জারজ বললেও জারজ কে অপমান করা হবে তুই একটা নিকৃষ্ট জানোয়ারের বাচ্চা। নিকৃষ্ট দেশের নিকৃষ্ট জাতি মালাউনরা।",
|
85 |
+
"স্যার, ১০ টাকায় চাল আর ঘরে ঘরে চাকুরী দেয়ার কথা ছিল যে সেটার কি হল??",
|
86 |
+
"চরমোনাইর কারণে দেশে ইসলামের আইন আসতেছে না।.....দালাল....দালাল....",
|
87 |
+
"মালাউনেরা বাংলাদেশে এতো ঘন ঘন কেনো আসে ?",
|
88 |
+
"আমি ও তো জানি যে তুই কি তুই একটা খানকি, মাগী, কালো চুদি",
|
89 |
+
"কুত্তালীগ সেলিমের আইনের আওতায় এনে কঠিন শাস্তি হওয়া উচিত",
|
90 |
+
"দাদা এটা কোথাকার জায়গা বলুন আমাকে চুঁদতে হবে",
|
91 |
+
"সালা বেইমান বাংলাদেশের বরিশালের ছেলে গান গায় ভারতের তার মুখে কোনো কথাই মানায়না।",
|
92 |
+
"ভারত যখন দেখলো অনেক হিনদু ইসলাম গ্রহন করেছে তখনই জাকির নায়েকের পিছনে লাগছে",
|
93 |
+
"পতিতাদের জীবন খুব দুঃখের প্লিজ ওদের নিয়ে কোনো মজা করবেন না ওরা পেটের দায়ে এই খারাপ কাজ করে তাই ওদেরকে নিয়ে কোনো বাজে কথা বলবেন না",
|
94 |
+
"হিন্দুদের আঘাত করলে ভারতের জ্বলে।",
|
95 |
+
"মাগির ছেলে তোকে কে জেতে বলেছে",
|
96 |
+
"যখন দেখবে কলকাতার দাদারা আমার প্রশংসা করছে, তখনই বুঝবে যে আমি আমার দেশের বিরুদ্ধে কাজ করছি",
|
97 |
+
"দুর্নীতিবাজ পুলিশ দিয়ে অভিযানের নামে সাধারণ মানুষকেই হয়রানি করা হবে",
|
98 |
+
]
|
99 |
+
|
100 |
+
|
101 |
+
st.set_page_config(
|
102 |
+
page_title="HateGuard",
|
103 |
+
page_icon="🧊",
|
104 |
+
layout="centered",
|
105 |
+
initial_sidebar_state="expanded",
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
# fix sidebar
|
110 |
+
st.markdown("""
|
111 |
+
<style>
|
112 |
+
.css-vk3wp9 {
|
113 |
+
background-color: rgb(255 255 255);
|
114 |
+
}
|
115 |
+
.css-18l0hbk {
|
116 |
+
padding: 0.34rem 1.2rem !important;
|
117 |
+
margin: 0.125rem 2rem;
|
118 |
+
}
|
119 |
+
.css-nziaof {
|
120 |
+
padding: 0.34rem 1.2rem !important;
|
121 |
+
margin: 0.125rem 2rem;
|
122 |
+
background-color: rgb(181 197 227 / 18%) !important;
|
123 |
+
}
|
124 |
+
</style>
|
125 |
+
""", unsafe_allow_html=True
|
126 |
+
)
|
127 |
+
hide_st_style = """
|
128 |
+
<style>
|
129 |
+
#MainMenu {visibility: hidden;}
|
130 |
+
footer {visibility: hidden;}
|
131 |
+
header {visibility: hidden;}
|
132 |
+
</style>
|
133 |
+
"""
|
134 |
+
st.markdown(hide_st_style, unsafe_allow_html=True)
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
def app():
|
139 |
+
|
140 |
+
st.title("HateGuard")
|
141 |
+
|
142 |
+
|
143 |
+
st.markdown(
|
144 |
+
"""A Bangla hate speech detection system using pre-trained Transformer Encoder. Used FGM,
|
145 |
+
AWP, Layer freezing and Attention Pooling for improving performance. A total of 21 experiment was carried out achieving 82% accuracy.
|
146 |
+
"""
|
147 |
+
)
|
148 |
+
|
149 |
+
|
150 |
+
image = get_image()
|
151 |
+
st.image(image, use_column_width=True)
|
152 |
+
st.markdown("[![View in W&B](https://img.shields.io/badge/View%20in-W%26B-blue)](https://wandb.ai/shamim/HateGuard)")
|
153 |
+
|
154 |
+
st.markdown('---')
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
l = [f'Example {i}' for i in range(len(examples))]
|
159 |
+
text = st.selectbox("Examples", l)
|
160 |
+
idx = l.index(text)
|
161 |
+
text = st.text_area('Enter Text', examples[idx])
|
162 |
+
|
163 |
+
|
164 |
+
if st.button("Predict Scores", type="primary"):
|
165 |
+
with st.spinner("Predicting scores..."):
|
166 |
+
prob = infer(text).numpy()
|
167 |
+
st.success("Scores predicted successfully!")
|
168 |
+
|
169 |
+
idx = np.argpartition(prob, -4)[-4:]
|
170 |
+
st.markdown('#### Results')
|
171 |
+
|
172 |
+
idx = list(idx)
|
173 |
+
idx.sort(key=lambda x: prob[x].astype(float), reverse=True)
|
174 |
+
for i in idx:
|
175 |
+
class_name = get_class(i).capitalize()
|
176 |
+
class_probability = prob[i].astype(float)
|
177 |
+
st.write(f'{class_name}: {class_probability:.2%}')
|
178 |
+
st.progress(class_probability)
|
179 |
+
|
180 |
+
app()
|
181 |
+
|
182 |
+
# Display a footer with links and credits
|
183 |
+
st.markdown("---")
|
184 |
+
st.markdown("Built by [Shamim Ahamed](https://www.shamimahamed.com/). Dataset used [Bengali Hate v2.0](https://github.com/rezacsedu/Bengali-Hate-Speech-Dataset/blob/main/bengali_hate_v2.0.csv)")
|
185 |
|
dataset.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from utils import read_yaml
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
class BanglaHSDataset(Dataset):
|
11 |
+
def __init__(self, tokenizer, max_length):
|
12 |
+
self.tokenizer = tokenizer
|
13 |
+
self.max_length = max_length
|
14 |
+
|
15 |
+
|
16 |
+
def __len__(self): return 0
|
17 |
+
|
18 |
+
def __getitem__(self, text):
|
19 |
+
inputs = self.tokenizer(
|
20 |
+
text,
|
21 |
+
max_length=self.max_length, padding='max_length',
|
22 |
+
truncation=True,
|
23 |
+
return_offsets_mapping=False
|
24 |
+
)
|
25 |
+
for k, v in inputs.items(): inputs[k] = torch.tensor(v, dtype=torch.long).unsqueeze(dim=0)
|
26 |
+
label = torch.tensor(0, dtype=torch.float)
|
27 |
+
return inputs, label
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
def get_class(index):
|
33 |
+
ind2cat = [
|
34 |
+
'Geopolitical',
|
35 |
+
'Personal',
|
36 |
+
'Political',
|
37 |
+
'Religious',
|
38 |
+
]
|
39 |
+
return ind2cat[index]
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
if __name__ == '__main__':
|
46 |
+
cfg = read_yaml('./baseline.yaml')
|
47 |
+
|
48 |
+
# cfg.Model.target_size = 6
|
49 |
+
# model = BanglaHS_Model(cfg.Model)
|
50 |
+
# #model.load_state_dict(torch.load('./model_fold-0_best.pt', map_location=torch.device('cpu')))
|
51 |
+
# model.eval()
|
52 |
+
|
53 |
+
# ds = BanglaHSDataset(cfg.Dataset, model)
|
54 |
+
|
55 |
+
# x = ds['Hello hi'][0]
|
56 |
+
|
57 |
+
# with torch.no_grad():
|
58 |
+
# y = model(x)
|
59 |
+
# print('y:', y)
|
60 |
+
|
61 |
+
|
image/wandb.jpg
ADDED
model.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def weight_init_normal(module, model):
|
7 |
+
if isinstance(module, nn.Linear):
|
8 |
+
module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
|
9 |
+
if module.bias is not None:
|
10 |
+
module.bias.data.zero_()
|
11 |
+
elif isinstance(module, nn.Embedding):
|
12 |
+
module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
|
13 |
+
if module.padding_idx is not None:
|
14 |
+
module.weight.data[module.padding_idx].zero_()
|
15 |
+
elif isinstance(module, nn.LayerNorm):
|
16 |
+
module.bias.data.zero_()
|
17 |
+
module.weight.data.fill_(1.0)
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
class MeanPooling(nn.Module):
|
22 |
+
def __init__(self):
|
23 |
+
super(MeanPooling, self).__init__()
|
24 |
+
|
25 |
+
def forward(self, last_hidden_state, attention_mask):
|
26 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
|
27 |
+
sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
|
28 |
+
sum_mask = input_mask_expanded.sum(1)
|
29 |
+
sum_mask = torch.clamp(sum_mask, min=1e-9)
|
30 |
+
mean_embeddings = sum_embeddings / sum_mask
|
31 |
+
return mean_embeddings
|
32 |
+
|
33 |
+
|
34 |
+
class MeanPoolingLayer(nn.Module):
|
35 |
+
def __init__(self,
|
36 |
+
hidden_size,
|
37 |
+
target_size,
|
38 |
+
dropout = 0,
|
39 |
+
):
|
40 |
+
super(MeanPoolingLayer, self).__init__()
|
41 |
+
self.pool = MeanPooling()
|
42 |
+
self.fc = nn.Sequential(
|
43 |
+
nn.Dropout(dropout),
|
44 |
+
nn.Linear(hidden_size, target_size),
|
45 |
+
nn.Sigmoid()
|
46 |
+
)
|
47 |
+
|
48 |
+
def forward(self, inputs, mask):
|
49 |
+
last_hidden_states = inputs[0]
|
50 |
+
feature = self.pool(last_hidden_states, mask)
|
51 |
+
outputs = self.fc(feature)
|
52 |
+
return outputs
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
class HSLanguageModel(nn.Module):
|
57 |
+
def __init__(self,
|
58 |
+
backbone = 'microsoft/deberta-v3-small',
|
59 |
+
target_size = 1,
|
60 |
+
head_dropout = 0,
|
61 |
+
reinit_nlayers = 0,
|
62 |
+
freeze_nlayers = 0,
|
63 |
+
reinit_head = True,
|
64 |
+
grad_checkpointing = False,
|
65 |
+
):
|
66 |
+
super(HSLanguageModel, self).__init__()
|
67 |
+
|
68 |
+
self.config = AutoConfig.from_pretrained(backbone, output_hidden_states=True)
|
69 |
+
self.model = AutoModel.from_pretrained(backbone, config=self.config)
|
70 |
+
self.head = MeanPoolingLayer(
|
71 |
+
self.config.hidden_size,
|
72 |
+
target_size,
|
73 |
+
head_dropout
|
74 |
+
)
|
75 |
+
self.tokenizer = AutoTokenizer.from_pretrained(backbone);
|
76 |
+
|
77 |
+
|
78 |
+
if grad_checkpointing == True:
|
79 |
+
print('Gradient ckpt enabled')
|
80 |
+
self.model.gradient_checkpointing_enable()
|
81 |
+
|
82 |
+
if reinit_nlayers > 0:
|
83 |
+
# Reinit last n encoder layers
|
84 |
+
# [TODO] Check if it is autoencoding model: Bert, Roberta, DistilBert, Albert, XLMRoberta, BertModel
|
85 |
+
for layer in self.model.encoder.layer[-reinit_nlayers:]:
|
86 |
+
self._init_weights(layer)
|
87 |
+
|
88 |
+
if freeze_nlayers > 0:
|
89 |
+
self.model.embeddings.requires_grad_(False)
|
90 |
+
self.model.encoder.layer[:freeze_nlayers].requires_grad_(False)
|
91 |
+
|
92 |
+
if reinit_head:
|
93 |
+
# Reinit layers in head
|
94 |
+
self._init_weights(self.head)
|
95 |
+
|
96 |
+
|
97 |
+
def _init_weights(self, layer):
|
98 |
+
for module in layer.modules():
|
99 |
+
init_fn = weight_init_normal
|
100 |
+
init_fn(module, self)
|
101 |
+
|
102 |
+
|
103 |
+
def forward(self, inputs):
|
104 |
+
outputs = self.model(**inputs)
|
105 |
+
outputs = self.head(outputs, inputs['attention_mask'])
|
106 |
+
return outputs
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == '__main__':
|
110 |
+
|
111 |
+
model = HSLanguageModel()
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
|
model_ind2cat.csv
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,labels
|
2 |
+
0,Geopolitical
|
3 |
+
1,Personal
|
4 |
+
2,Political
|
5 |
+
3,Religious
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.21.0
|
2 |
+
Pillow
|
3 |
+
protobuf
|
4 |
+
torchvision==0.15.2
|
5 |
+
torch==2.0.1
|
6 |
+
numpy
|
7 |
+
pandas
|
8 |
+
transformers==4.21.2
|
9 |
+
tokenizers==0.12.1
|
10 |
+
transformers[sentencepiece]
|
11 |
+
clean-text==0.6.0
|
utils.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Created on Tue Jul 11 12:56:00 2023
|
5 |
+
|
6 |
+
Copyright (c): Shamim Ahamed
|
7 |
+
"""
|
8 |
+
|
9 |
+
import yaml
|
10 |
+
from addict import Dict
|
11 |
+
|
12 |
+
class MyDict(Dict):
|
13 |
+
def __missing__(self, name):
|
14 |
+
raise KeyError(name)
|
15 |
+
|
16 |
+
|
17 |
+
def read_yaml(fpath):
|
18 |
+
with open(fpath, mode="r") as file:
|
19 |
+
yml = yaml.safe_load(file)
|
20 |
+
return MyDict(yml)
|