from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    FSMTForConditionalGeneration,
    FSMTTokenizer,
)
from lxml_html_clean import Cleaner  # not used directly; fails fast if newspaper's lxml cleaning dependency is missing
from langdetect import detect
from newspaper import Article
from PIL import Image
import streamlit as st
import torch

st.markdown("## Prediction of Misinformation by given URL")
background = Image.open('logo.jpg')
st.image(background)

st.markdown("### Article URL")
text = st.text_area(
    "Insert a URL here",
    value="https://www.livelaw.in/news-updates/supreme-court-collegium-recommends-appointment-advocate-praveen-kumar-giri-judge-allahabad-high-court-279470",
)


# Models and tokenizers are unserializable resources, so cache them with
# st.cache_resource (st.cache_data would try to pickle them on every rerun).
@st.cache_resource
def get_models_and_tokenizers():
    model_name = 'distilbert-base-uncased-finetuned-sst-2-english'
    checkpoint_dir = './my_saved_model/checkpoint-6320/'  # Path to your checkpoint folder

    # Load the fine-tuned classifier from the checkpoint and the matching tokenizer.
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load the Russian-to-English translator model and tokenizer.
    model_name_translator = "facebook/wmt19-ru-en"
    tokenizer_translator = FSMTTokenizer.from_pretrained(model_name_translator)
    model_translator = FSMTForConditionalGeneration.from_pretrained(model_name_translator)

    model.eval()
    model_translator.eval()
    return model, tokenizer, model_translator, tokenizer_translator


model, tokenizer, model_translator, tokenizer_translator = get_models_and_tokenizers()

# Download and parse the article, then detect its language from title + body.
article = Article(text)
article.download()
article.parse()
concated_text = article.title + '. ' + article.text
lang = detect(concated_text)

st.markdown("### Language detection")
if lang == 'ru':
    st.markdown(f"The language of this article is {lang.upper()}, so we translated it!")
    with st.spinner('Waiting for translation'):
        input_ids = tokenizer_translator.encode(concated_text, return_tensors="pt", max_length=512, truncation=True)
        outputs = model_translator.generate(input_ids)
        decoded = tokenizer_translator.decode(outputs[0], skip_special_tokens=True)
        st.markdown("### Translated Text")
        st.markdown(f"{decoded[:777]}")
    concated_text = decoded
else:
    st.markdown(f"The language of this article is {lang.upper()}!")

st.markdown("### Extracted Text")
st.markdown(f"{concated_text[:777]}")

# Classify the text; logit index 1 is treated as the "fake" class,
# scaled to an integer percentage for st.progress.
tokens_info = tokenizer(concated_text, truncation=True, return_tensors="pt")
with torch.no_grad():
    raw_predictions = model(**tokens_info)
softmaxed = int(torch.nn.functional.softmax(raw_predictions.logits[0], dim=0)[1] * 100)

st.markdown("### Truthteller Predicts...")
st.progress(softmaxed)
st.markdown(f"This text is fake with a probability of *{softmaxed}%*!")
if softmaxed > 70:
    st.error('We would not trust this text! This is misleading.')
elif softmaxed > 40:
    st.warning('We are not sure about this text!')
else:
    st.success('We would trust this text!')
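
# To try the app locally (assuming this file is saved as app.py, and that
# logo.jpg and ./my_saved_model/checkpoint-6320/ sit next to it):
#   streamlit run app.py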