## LIBRARIES ###
## Data
import numpy as np
import pandas as pd
import torch
import json
from tqdm import tqdm
from math import floor
from datasets import load_dataset
from collections import defaultdict
from transformers import AutoTokenizer

pd.options.display.float_format = '${:,.2f}'.format

# Analysis
# from gensim.models.doc2vec import Doc2Vec
# from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import nltk
from nltk.cluster import KMeansClusterer
import scipy.spatial.distance as sdist
from scipy.spatial import distance_matrix
# nltk.download('punkt')  # make sure that punkt is downloaded

# App & Visualization
import streamlit as st
import altair as alt
import plotly.graph_objects as go
from streamlit_vega_lite import altair_component

# utils
from random import sample
from error_analysis import utils as ut


def down_samp(embedding, max_points=1000):
    """Down-sample a data frame for Altair visualization.

    Every ``slice`` group is shrunk by one common ratio so the total number
    of plotted points stays near ``max_points`` while the relative size of
    the slices is kept. The per-slice budget is the slice's largest
    (slice, label) group count multiplied by that ratio.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Must contain ``slice``, ``label`` and ``content`` columns.
    max_points : int, optional
        Approximate upper bound on the number of returned rows
        (default 1000, the previously hard-coded value).

    Returns
    -------
    pandas.DataFrame
        The sampled rows, re-indexed from 0.
    """
    if embedding.empty:
        # Nothing to sample; also avoids dividing by a zero total below.
        return embedding.reset_index(drop=True)

    # Rows reserved for user-entered sentences. The code that used to set
    # this from a "Your Sentences" slice is currently disabled, so it is 0.
    user_data = 0

    total_size = embedding.groupby(['slice', 'label'], as_index=False).count()
    max_sample = total_size.groupby('slice').max()['content']

    # Same shrink factor for every slice keeps the proportions. Capped at 1 so
    # small data never asks ``sample`` for more rows than a group holds
    # (which raises ValueError without replacement).
    total = float(max_sample.astype(float).sum())
    ratio = min(1.0, (max_points - user_data) / total)

    max_samp = max_sample.apply(lambda x: floor(x * ratio)).astype(int).to_dict()
    max_samp['Your Sentences'] = user_data

    # Sample each slice down to its budget.
    embedding = (
        embedding.groupby('slice')
        .apply(lambda x: x.sample(n=max_samp.get(x.name)))
        .reset_index(drop=True)
    )
    return embedding


def data_comparison(df):
    """Build the error-slice scatter plot with a clickable legend.

    Points in the ``high-loss`` slice are coloured by ``cluster``; all other
    points are drawn light gray. Clicking legend marks selects matching
    (cluster, label) combinations, which are drawn more opaque in the
    scatter plot.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``x``, ``y``, ``cluster``, ``slice``, ``content``, ``label``
        and ``pred`` columns.

    Returns
    -------
    Altair chart
        Scatter plot and legend side by side (``scatter | legend``), with
        axis grid and view stroke disabled.
    """
    # Selection ``fields`` are raw data field names, not encoding shorthand:
    # "cluster:N" would refer to a non-existent field and match every point.
    selection = alt.selection_multi(fields=['cluster', 'label'])
    color = alt.condition(
        alt.datum.slice == 'high-loss',
        alt.Color('cluster:N', scale=alt.Scale(domain=df.cluster.unique().tolist())),
        alt.value("lightgray"),
    )
    opacity = alt.condition(selection, alt.value(0.7), alt.value(0.25))

    # basic chart
    scatter = alt.Chart(df).mark_point(size=100, filled=True).encode(
        x=alt.X('x:Q', axis=None),
        y=alt.Y('y:Q', axis=None),
        color=color,
        shape=alt.Shape('label:O', scale=alt.Scale(range=['circle', 'diamond'])),
        tooltip=['cluster:N', 'slice:N', 'content:N', 'label:O', 'pred:O'],
        opacity=opacity,
    ).properties(
        width=1000,
        height=800,
    ).interactive()

    legend = alt.Chart(df).mark_point(size=100, filled=True).encode(
        x=alt.X("label:O"),
        y=alt.Y('cluster:N', axis=alt.Axis(orient='right'), title=""),
        shape=alt.Shape('label:O', scale=alt.Scale(range=['circle', 'diamond']), legend=None),
        color=color,
    ).add_selection(
        selection
    )

    layered = scatter | legend
    layered = layered.configure_axis(
        grid=False
    ).configure_view(
        strokeOpacity=0
    )
    return layered


def quant_panel(embedding_df):
    """Quantitative panel layout.

    Renders a header, a "How to read this chart" expander, and the
    down-sampled error-slice chart produced by ``data_comparison``.
    """
    st.warning("**Error slice visualization**")
    with st.expander("How to read this chart:"):
        st.markdown("* Each **point** is an input example.")
        st.markdown("* Gray points have low-loss and the colored have high-loss. High-loss instances are clustered using **kmeans** and each color represents a cluster.")
        st.markdown("* The **shape** of each point reflects the label category -- positive (diamond) or negative sentiment (circle).")
    st.altair_chart(data_comparison(down_samp(embedding_df)), use_container_width=True)


def frequent_tokens(data, tokenizer, loss_quantile=0.95, top_k=200, smoothing=0.005):
    """Rank tokens by how over-represented they are among high-loss examples.

    An example is "high-loss" when its loss is strictly above the
    ``loss_quantile`` quantile of ``data['loss']``. Each distinct token of an
    example contributes once, weighted uniformly (``Freq``) and by the
    normalised high-loss indicator (``Freq error slice``). ``lrs`` is the
    smoothed ratio ``(smoothing + error_freq) / (smoothing + freq)``.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs ``content`` (text) and ``loss`` (castable to float) columns.
    tokenizer : callable
        HuggingFace-style tokenizer; must return ``input_ids`` as a tensor
        for ``return_tensors='pt'`` and provide ``decode``.
    loss_quantile, top_k, smoothing
        Threshold quantile, number of tokens returned, additive smoothing.

    Returns
    -------
    pandas.DataFrame
        Columns ``Token``, ``Freq``, ``Freq error slice``, ``lrs`` (all
        pre-formatted strings), sorted by descending ``lrs``.
    """
    columns = ['Token', 'Freq', 'Freq error slice', 'lrs']
    num_examples = len(data)
    if num_examples == 0:
        return pd.DataFrame([], columns=columns)

    # Distinct token ids per example: a token counts at most once per example.
    unique_tokens = [
        torch.unique(tokenizer(row, padding=True, return_tensors='pt')['input_ids'])
        for row in tqdm(data['content'])
    ]

    losses = data['loss'].astype(float)
    high_loss = losses.quantile(loss_quantile)
    # Positional arrays: indexing a Series with ``[i]`` is label-based and
    # picks the wrong rows (or raises) for frames without a 0..n-1 index.
    is_error = (losses > high_loss).to_numpy(dtype=float)
    num_errors = is_error.sum()
    # If no example is above the threshold, keep zero weights instead of 0/0.
    loss_weights = is_error / num_errors if num_errors else is_error
    weights_uniform = np.full(num_examples, 1 / num_examples)

    token_frequencies = defaultdict(float)
    token_frequencies_error = defaultdict(float)
    for example_tokens, w_uniform, w_error in tqdm(
            zip(unique_tokens, weights_uniform, loss_weights), total=num_examples):
        for token in example_tokens:
            token_id = token.item()
            token_frequencies[token_id] += w_uniform
            token_frequencies_error[token_id] += w_error

    token_lrs = {
        k: (smoothing + token_frequencies_error[k]) / (smoothing + token_frequencies[k])
        for k in token_frequencies
    }
    tokens_sorted = [k for k, _ in sorted(token_lrs.items(), key=lambda kv: kv[1])[::-1]]

    top_tokens = [
        ['%10s' % (tokenizer.decode(token)),
         '%.4f' % (token_frequencies[token]),
         '%.4f' % (token_frequencies_error[token]),
         '%4.2f' % (token_lrs[token])]
        for token in tokens_sorted[:top_k]
    ]
    return pd.DataFrame(top_tokens, columns=columns)


@st.cache(ttl=600)
def get_data(inference, emb):
    """Assemble the per-example analysis frame from model outputs.

    NOTE(review): reads the module-level ``dataset`` global, which is not
    defined in the visible part of this file (presumably set in
    ``__main__``); ``dataset[:n]`` is assumed to return a mapping of
    column name -> list, as HuggingFace ``datasets`` slicing does.

    Parameters
    ----------
    inference
        Object exposing ``outputs`` and ``losses`` tensors (``.numpy()``).
    emb
        2-D projection with two columns, used as ``x`` / ``y``.

    Returns
    -------
    pandas.DataFrame
        Columns ``content``, ``label``, ``pred``, ``loss``, ``x``, ``y``.
        Because ``np.vstack`` mixes text with numbers, ``label``, ``pred``
        and ``loss`` come back as strings (callers cast, e.g. ``astype(float)``).
    """
    preds = inference.outputs.numpy()
    losses = inference.losses.numpy()
    embeddings = pd.DataFrame(emb, columns=['x', 'y'])
    num_examples = len(losses)
    rows = dataset[:num_examples]
    table = np.transpose(np.vstack([rows['content'], rows['label'], preds, losses]))
    return pd.concat(
        [pd.DataFrame(table, columns=['content', 'label', 'pred', 'loss']), embeddings],
        axis=1,
    )


def clustering(data, num_clusters):
    """Cluster the ``embedding`` vectors of *data* with cosine k-means.

    Writes ``cluster`` (int) and ``centroid`` (the cluster mean vector)
    columns into *data* in place and returns it.

    Returns
    -------
    tuple
        ``(data, assigned_clusters)`` where ``assigned_clusters`` is the list
        of cluster ids returned by ``KMeansClusterer.cluster``.
    """
    X = np.array(data['embedding'].tolist())
    kclusterer = KMeansClusterer(
        num_clusters,
        distance=nltk.cluster.util.cosine_distance,
        repeats=25,
        avoid_empty_clusters=True,
    )
    assigned_clusters = kclusterer.cluster(X, assign_clusters=True)
    # ``means()`` does not change per row; fetch it once instead of per apply call.
    means = kclusterer.means()
    data['cluster'] = pd.Series(assigned_clusters, index=data.index).astype('int')
    data['centroid'] = data['cluster'].apply(lambda c: means[c])
    return data, assigned_clusters


def kmeans(df, num_clusters=3):
    """Cluster the ``high-loss`` slice of *df* and merge the labels back.

    Rows outside the high-loss slice get ``cluster == num_clusters`` (an
    extra "not clustered" id) and a missing ``centroid``.

    Returns
    -------
    pandas.DataFrame
        *df*'s columns plus fresh ``cluster`` and ``centroid`` columns.
    """
    # Work on a copy: clustering() adds columns, and writing into a boolean
    # .loc slice of df raises SettingWithCopyWarning.
    data_hl = df.loc[df['slice'] == 'high-loss'].copy()
    data_kmeans, _ = clustering(data_hl, num_clusters)
    # Drop any stale assignment from a previous run; otherwise the fresh
    # columns would get the '_y' suffix and be discarded below.
    base = df.drop(columns=['cluster', 'centroid'], errors='ignore')
    merged = pd.merge(base, data_kmeans, left_index=True, right_index=True,
                      how='outer', suffixes=('', '_y'))
    merged.drop(columns=merged.filter(regex='_y$').columns.tolist(), inplace=True)
    merged['cluster'] = merged['cluster'].fillna(num_clusters).astype('int')
    return merged


def distance_from_centroid(row):
    """Return the Euclidean distance between ``row['embedding']`` and ``row['centroid']``."""
    # np.linalg.norm is the public API (``sdist.norm`` was an undocumented
    # re-export); asarray lets either value be a list or an array.
    return np.linalg.norm(np.asarray(row['embedding']) - np.asarray(row['centroid']))


@st.cache(ttl=600)
def topic_distribution(weights, smoothing=0.01):
    """Compare topic (``title``) frequencies overall vs. in a weighted slice.

    NOTE(review): reads the module-level ``dataset`` global (not defined in
    the visible part of this file); ``weights[i]`` is assumed to align with
    ``dataset[i]``.

    Parameters
    ----------
    weights : array-like of numbers
        Per-example weight of the slice of interest.
    smoothing : float
        Additive smoothing for the ratio.

    Returns
    -------
    pandas.DataFrame
        Columns ``Overall frequency``, ``Error frequency``, ``Ratio``,
        ``Category`` (formatted strings), sorted by descending ratio.
    """
    columns = ['Overall frequency', 'Error frequency', 'Ratio', 'Category']
    num_examples = len(weights)
    if num_examples == 0:
        return pd.DataFrame([], columns=columns)

    # Explicit float array: np.full_like would inherit an int/bool dtype from
    # ``weights`` and truncate 1/n to 0 (or turn it into True).
    weights_uniform = np.full(num_examples, 1 / num_examples)

    topic_frequencies = defaultdict(float)
    topic_frequencies_spotlight = defaultdict(float)
    for i in range(num_examples):
        category = dataset[i]['title']
        topic_frequencies[category] += weights_uniform[i]
        topic_frequencies_spotlight[category] += weights[i]

    topic_ratios = {
        c: (smoothing + topic_frequencies_spotlight[c]) / (smoothing + topic_frequencies[c])
        for c in topic_frequencies
    }
    categories_sorted = sorted(topic_ratios, key=topic_ratios.get, reverse=True)

    topic_distr = [
        ['%.3f' % topic_frequencies[category],
         '%.3f' % topic_frequencies_spotlight[category],
         '%.2f' % topic_ratios[category],
         '%s' % category]
        for category in categories_sorted
    ]
    return pd.DataFrame(topic_distr, columns=columns)


def populate_session(dataset, model):
    """Initialise Streamlit session state for a dataset/model pair.

    Reads ``./assets/data/{dataset}_{model}.parquet`` and stores it under
    ``st.session_state["user_data"]`` unless that key already exists; also
    ensures ``st.session_state["selected_slice"]`` exists (``None``).
    """
    data_df = read_file_to_df('./assets/data/' + dataset + '_' + model + '.parquet')
    if model == 'albert-base-v2-yelp-polarity':
        tokenizer_id = 'textattack/' + model
    else:
        tokenizer_id = model
    # NOTE(review): the tokenizer is loaded (which may download/cache it) but
    # is never used or stored; presumably it was meant for session state.
    _tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    if "user_data" not in st.session_state:
        st.session_state["user_data"] = data_df
    if "selected_slice" not in st.session_state:
        st.session_state["selected_slice"] = None


@st.cache(allow_output_mutation=True)
def read_file_to_df(file):
    """Read a parquet file into a DataFrame (cached by Streamlit)."""
    return pd.read_parquet(file)


if __name__ == "__main__":
    ### STREAMLIT APP CONFIG ###
    st.set_page_config(layout="wide", page_title="Interactive Error Analysis")

    ut.init_style()

    lcol, rcol = st.columns([2, 2])
    # ******* loading the model and the data
    #st.sidebar.mardown("