|
import streamlit as st |
|
from transformers import pipeline |
|
from textblob import TextBlob |
|
from transformers import BertForSequenceClassification, AdamW, BertConfig |
|
# Page-wide configuration; this is the first Streamlit call in the script.
st.set_page_config(layout='wide', initial_sidebar_state='expanded')

# Two side-by-side columns: description on the left, input widgets on the right.
col1, col2 = st.columns(2)

# Right column: message to classify plus an explicit trigger button.
text = col2.text_input("Enter the text you'd like to analyze for spam.")
aButton = col2.button('Analyze')

# Left column: title followed by the descriptive paragraphs, in order.
col1.title("Spamd: Turkish Spam Detector")
for _paragraph in (
    "Message spam detection tool for Turkish language. Due the small size of the dataset, I decided to go with transformers technology Google BERT. Using the Turkish pre-trained model BERTurk, I imporved the accuracy of the tool by 18 percent compared to the previous model which used fastText.",
    "Original file is located at",
    "https://colab.research.google.com/drive/1QuorqAuLsmomesZHsaQHEZgzbPEM8YTH",
):
    col1.markdown(_paragraph)
|
|
|
import torch |
|
import numpy as np |
|
from transformers import AutoTokenizer |
|
from transformers import AutoModel


def _cache_across_reruns(loader):
    """Wrap ``loader`` so its result is reused across Streamlit script reruns.

    Streamlit re-executes this script on every widget interaction, so without
    caching the tokenizer and model would be reloaded for every keystroke or
    button click. Uses ``st.cache_resource`` when the installed Streamlit
    provides it, falls back to the older ``st.experimental_singleton``, and
    otherwise returns ``loader`` unchanged (original uncached behavior).
    """
    decorator = (
        getattr(st, "cache_resource", None)
        or getattr(st, "experimental_singleton", None)
    )
    return decorator(loader) if decorator is not None else loader


@_cache_across_reruns
def _load_tokenizer():
    """Load the BERTurk tokenizer."""
    return AutoTokenizer.from_pretrained("dbmdz/bert-base-turkish-uncased")


@_cache_across_reruns
def _load_model():
    """Load the fine-tuned spam classification model."""
    return BertForSequenceClassification.from_pretrained("NimaKL/spamd_model")


tokenizer = _load_tokenizer()
model = _load_model()

# Not referenced by the code visible in this file; kept so existing
# module-level names stay defined.
token_id = []
attention_masks = []
|
def preprocessing(input_text, tokenizer, max_length=32):
    """Tokenize ``input_text`` into fixed-length inputs for the model.

    Args:
        input_text: The text to encode.
        tokenizer: Tokenizer object exposing ``encode_plus``.
        max_length: Length every encoding is padded or truncated to
            (defaults to 32).

    Returns:
        The tokenizer's ``BatchEncoding`` (PyTorch tensors) with the fields:
          - input_ids: token ids
          - token_type_ids: token type ids
          - attention_mask: 0/1 flags marking which tokens the model should
            attend to (requested via ``return_attention_mask=True``)
    """
    return tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=max_length,
        # Explicit replacements for the deprecated ``pad_to_max_length=True``:
        # pad short inputs up to ``max_length`` and cut long ones down to it,
        # so the result always has the same sequence length.
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
|
# Device the input tensors are moved to inside predict(). The model is loaded
# without an explicit .to(...), so changing this requires moving the model too.
device = 'cpu'
|
|
|
def predict(new_sentence):
    """Classify a single message with the module-level model.

    Args:
        new_sentence: The message text to classify.

    Returns:
        ``'Predicted Class: Spam'`` when the highest-scoring logit is at
        index 1, otherwise ``'Predicted Class: Normal'``.
    """
    # One sentence yields a batch of one, so the encoded tensors can be fed
    # to the model directly without collecting and concatenating them.
    encoding = preprocessing(new_sentence, tokenizer)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = model(input_ids, token_type_ids=None, attention_mask=attention_mask)

    predicted_index = output.logits.argmax(dim=-1).item()
    prediction = 'Spam' if predicted_index == 1 else 'Normal'
    return f'Predicted Class: {prediction}'
|
|
|
# Run when text has been entered or the Analyze button was pressed.
if text or aButton:
    with col2:
        if not text.strip():
            # The button was pressed (or only whitespace was entered) with
            # nothing to classify: ask for input instead of running the
            # model on an empty message and showing a meaningless result.
            st.warning('Please enter some text to analyze.')
        else:
            with st.spinner('Wait for it...'):
                st.success(predict(text))