"""Streamlit demo app for keyphrase extraction/generation with Transformers.

Lets the user pick a model listed in ``config.json``, run it on some input text,
and shows the keyphrases highlighted in the text next to a history of runs.
"""
import re
import string

import numpy as np
import orjson
import pandas as pd
import streamlit as st
from annotated_text.util import get_annotated_html
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode

from pipelines.keyphrase_extraction_pipeline import KeyphraseExtractionPipeline
from pipelines.keyphrase_generation_pipeline import KeyphraseGenerationPipeline

# Placeholder delimiters taken from the Unicode private-use area. They should not
# occur in normal input text and are neither letters nor digits, so a keyphrase
# cannot match across them. The closing delimiter also keeps a digit that follows
# a keyphrase in the text from being read as part of the placeholder's index.
_TOKEN_OPEN = "\ue000"
_TOKEN_CLOSE = "\ue001"
_TOKEN_PATTERN = re.compile(_TOKEN_OPEN + r"(\d+)" + _TOKEN_CLOSE)


@st.cache(allow_output_mutation=True, show_spinner=False)
def load_pipeline(chosen_model):
    """Build (and cache) the pipeline matching ``chosen_model``.

    Args:
        chosen_model: Full model identifier ("<author>/<name>"). Its name must
            contain "keyphrase-extraction" or "keyphrase-generation".

    Returns:
        A ``KeyphraseExtractionPipeline`` or a ``KeyphraseGenerationPipeline``.

    Raises:
        ValueError: If the name matches neither pipeline type.
    """
    if "keyphrase-extraction" in chosen_model:
        return KeyphraseExtractionPipeline(chosen_model)
    if "keyphrase-generation" in chosen_model:
        return KeyphraseGenerationPipeline(chosen_model)
    # Fail here with a clear message. Returning None would only surface later as
    # a "NoneType is not callable" error when the pipeline is invoked.
    raise ValueError(
        f"Cannot infer pipeline type for model {chosen_model!r}: expected "
        "'keyphrase-extraction' or 'keyphrase-generation' in its name."
    )


def extract_keyphrases():
    """Run the pipeline on the current input and append the run to the history.

    Used as the ``on_click`` callback of the "Extract" button. Reads the
    module-level ``pipe`` plus ``st.session_state.input_text`` and
    ``st.session_state.chosen_model``. Writes ``st.session_state.keyphrases``
    and appends one row (model, text, keyphrase "0", "1", ...) to
    ``st.session_state.data_frame``.
    """
    st.session_state.keyphrases = pipe(st.session_state.input_text)
    keyphrases = list(st.session_state.keyphrases)

    # Build the row as a plain list. np.concatenate with an empty keyphrase list
    # would rely on str/float dtype promotion, which can fail or warn depending
    # on the NumPy version.
    row = [st.session_state.chosen_model, st.session_state.input_text, *keyphrases]
    columns = ["model", "text"] + [str(i) for i in range(len(keyphrases))]

    st.session_state.data_frame = pd.concat(
        [st.session_state.data_frame, pd.DataFrame(data=[row], columns=columns)],
        ignore_index=True,
        axis=0,
    ).fillna("")  # rows with fewer keyphrases get "" in the extra columns


def get_annotated_text(text, keyphrases):
    """Split ``text`` into plain strings and highlighted keyphrase tuples.

    The first whole-word, case-insensitive occurrence of each keyphrase is
    marked. The result is meant to be passed to ``get_annotated_html``.

    Args:
        text: Text to annotate.
        keyphrases: Sequence of keyphrase strings.

    Returns:
        A list whose items are either padded plain-word strings or
        ``(word, "KEY", "#21c354")`` tuples for words containing a keyphrase.
    """
    # Pass 1: replace the first occurrence of each keyphrase with an indexed
    # placeholder. Multi-word keyphrases thereby become a single space-free token.
    for index, keyphrase in enumerate(keyphrases):
        if not keyphrase:
            continue  # an empty pattern would inject a placeholder at a random spot
        pattern = (
            # Not preceded by a letter/digit, nor by a placeholder opener (which
            # stops digit-only keyphrases from matching placeholder indices).
            r"(?<![A-Za-z0-9" + _TOKEN_OPEN + r"])"
            + re.escape(keyphrase)  # keyphrases are literal text, not regexes
            + r"(?![A-Za-z])"  # also matches at the very end of the text
        )
        text = re.sub(
            pattern,
            f"{_TOKEN_OPEN}{index}{_TOKEN_CLOSE}",
            text,
            count=1,
            flags=re.I,
        )

    # Pass 2: split into words and map placeholders back to their keyphrases.
    words = text.split(" ")
    last_index = len(words) - 1  # word count *after* substitution, of this text
    result = []
    for i, word in enumerate(words):
        if _TOKEN_PATTERN.search(word):
            annotated = _TOKEN_PATTERN.sub(
                lambda match: keyphrases[int(match.group(1))], word
            )
            result.append((annotated, "KEY", "#21c354"))
        elif i == last_index:
            result.append(f" {word}")
        elif i == 0:
            result.append(f"{word} ")
        else:
            result.append(f" {word} ")
    return result


def _row_keyphrases(row):
    """Return the non-empty keyphrase cells of a one-row history DataFrame, in order."""
    # Keyphrase columns are named "0", "1", ...; sort numerically so "10" follows
    # "9". Selecting only digit-named columns also ignores any other columns
    # (e.g. extra ones the grid component may add to selected rows).
    keyphrase_columns = sorted(
        (column for column in row.columns if str(column).isdigit()),
        key=lambda column: int(column),
    )
    values = row.iloc[0][keyphrase_columns]
    return [str(value) for value in values if pd.notna(value) and str(value) != ""]


def rerender_output(layout):
    """Render the annotated output into ``layout``.

    Shows the latest extraction when there is one and no history row is
    selected; otherwise shows the first selected history row.

    Args:
        layout: Streamlit container (e.g. a column) to render into.
    """
    layout.subheader("🐧 Output")
    selected_rows = st.session_state.selected_rows

    if len(st.session_state.keyphrases) > 0 and len(selected_rows) == 0:
        text = st.session_state.input_text
        keyphrases = list(st.session_state.keyphrases)
        model = st.session_state.chosen_model
    else:
        text = selected_rows["text"].values[0]
        keyphrases = _row_keyphrases(selected_rows)
        # Use the model that produced this row, not whichever is selected now.
        model = selected_rows["model"].values[0]

    result = get_annotated_text(text, keyphrases)
    layout.markdown(
        get_annotated_html(*result),
        unsafe_allow_html=True,
    )

    if "generation" in model:
        # Generated keyphrases that do not appear verbatim (case-insensitively)
        # in the text cannot be highlighted, so list them separately.
        abstractive_keyphrases = [
            keyphrase for keyphrase in keyphrases if keyphrase.lower() not in text.lower()
        ]
        layout.write(", ".join(abstractive_keyphrases))


# set_page_config must be the first Streamlit command of the script run.
st.set_page_config(
    page_icon="🔑",
    page_title="Keyphrase extraction/generation with Transformers",
    layout="wide",
)

if "config" not in st.session_state:
    # Read bytes: orjson parses UTF-8 bytes directly, independent of the
    # platform's default text encoding.
    with open("config.json", "rb") as config_file:
        st.session_state.config = orjson.loads(config_file.read())
    st.session_state.data_frame = pd.DataFrame(columns=["model"])
    st.session_state.keyphrases = []

if "selected_rows" not in st.session_state:
    st.session_state.selected_rows = []

st.header("🔑 Keyphrase extraction/generation with Transformers")

col1, col2 = st.columns(2)

chosen_model = col1.selectbox(
    "Choose your model:",
    st.session_state.config.get("models"),
)
st.session_state.chosen_model = chosen_model

with st.spinner("Loading pipeline..."):
    pipe = load_pipeline(
        f"{st.session_state.config.get('model_author')}/{st.session_state.chosen_model}"
    )

st.session_state.input_text = col1.text_area(
    "Input", st.session_state.config.get("example_text"), height=300
).replace("\n", " ")

# NOTE(review): Streamlit runs on_click callbacks at the start of the next rerun,
# so this spinner probably never wraps the extraction itself -- confirm.
with st.spinner("Extracting keyphrases..."):
    col1.button("Extract", on_click=extract_keyphrases)

# NOTE(review): data_frame is created with a "model" column, so this condition is
# always true and the (possibly empty) history grid is always shown.
if len(st.session_state.data_frame.columns) > 0:
    st.subheader("📜 History")
    builder = GridOptionsBuilder.from_dataframe(
        st.session_state.data_frame, sortable=False
    )
    builder.configure_selection(selection_mode="single", use_checkbox=True)
    builder.configure_column("text", hide=True)
    grid_options = builder.build()
    data = AgGrid(
        st.session_state.data_frame,
        gridOptions=grid_options,
        update_mode=GridUpdateMode.SELECTION_CHANGED,
    )
    st.session_state.selected_rows = pd.DataFrame(data["selected_rows"])

if len(st.session_state.selected_rows) > 0 or len(st.session_state.keyphrases) > 0:
    rerender_output(col2)