"""Plotting utilities."""
import numpy as np
from typing import Tuple
from bokeh.layouts import column
from bokeh.models import Column, CustomJS, Slider
from bokeh.plotting import figure, Figure, ColumnDataSource
from bokeh.embed import components


def barplot(attended: np.ndarray, weights: np.ndarray, k: int = 4) -> Column:
    """
    Bokeh barplot showing the first k attention weights.

    k is interactively changeable via a slider. The plot always displays the
    first k entries of the inputs, so pass both arrays sorted by descending
    weight to obtain a true "top k" view.

    Args:
        attended (np.ndarray): Names of the attended entities.
        weights (np.ndarray): Attention weights, aligned with `attended`.
        k (int): Number of bars shown initially. Clamped to the number of
            entities so the slider never starts above its upper bound.
            Defaults to 4.

    Returns:
        bokeh.models.Column: Layout holding the slider stacked above the
            figure. Can be visualized for debugging via bokeh.plotting
            (i.e. output_file, show).
    """
    num_entities = len(attended)
    # Keep the slider's initial value inside [1, num_entities]; with fewer
    # entities than `k` the unclamped value would exceed the slider's end.
    initial_k = max(1, min(k, num_entities))

    source = ColumnDataSource(
        data=dict(attended=attended, weights=weights),
    )
    top_k_slider = Slider(
        start=1, end=num_entities, value=initial_k, step=1, title="k"
    )

    p = figure(
        # Only the first `initial_k` categories are visible at first; the
        # slider callback rewrites the factors afterwards.
        x_range=source.data["attended"][:initial_k],
        plot_height=350,
        title="Top k Gene Attention Weights",
        toolbar_location="below",
        tools="pan,wheel_zoom,box_zoom,save,reset",
    )
    p.vbar(x="attended", top="weights", source=source, width=0.9)

    # Browser-side callback: on every slider move, cut the full arrays down
    # to the first k entries and rescale both axes to the visible bars.
    callback = CustomJS(
        args=dict(
            source=source,
            xrange=p.x_range,
            yrange=p.y_range,
            attended=attended,
            weights=weights,
            top_k=top_k_slider,
        ),
        code="""
        var data = source.data;
        const k = top_k.value;
        data['attended'] = attended.slice(0, k);
        data['weights'] = weights.slice(0, k);
        source.change.emit();
        // not need if data is in descending order
        var yrange_arr = data['weights'];
        var yrange_max = Math.max(...yrange_arr) * 1.05;
        yrange.end = yrange_max;
        xrange.factors = data['attended'];
        source.change.emit();
        """,
    )
    top_k_slider.js_on_change("value", callback)

    layout = column(top_k_slider, p)
    p.xgrid.grid_line_color = None
    # Bars grow from zero, so anchor the y axis there.
    p.y_range.start = 0
    return layout


def embed_barplot(
    attended: np.ndarray, weights: np.ndarray, k: int = 4
) -> Tuple[str, str]:
    """Bokeh barplot showing the first k attention weights, ready to embed.

    k is interactively changeable via a slider. See `barplot` for how the
    displayed entries are selected.

    Args:
        attended (np.ndarray): Names of the attended entities.
        weights (np.ndarray): Attention weights, aligned with `attended`.
        k (int): Number of bars shown initially. Defaults to 4.

    Returns:
        Tuple[str, str]: javascript and html, as produced by
            bokeh.embed.components for the barplot layout.
    """
    return components(barplot(attended, weights, k=k))