# Author: Jhsmit
# Commit d1c70f0: switch to polars
from __future__ import annotations
import base64
import gzip
import json
from dataclasses import dataclass, fields
from io import BytesIO
from pathlib import Path
from urllib.parse import parse_qsl
import altair as alt
import ipywidgets as widgets
import numpy as np
import polars as pl
import solara
import solara.lab
from cmap import Colormap
from ipymolstar.widget import PDBeMolstar
from pydantic import BaseModel
from make_link import decode_data
# Element-wise version of np.base_repr: maps every integer of an array to its
# string representation in the requested base.
base_v = np.vectorize(pyfunc=np.base_repr)

# Fraction of the data span added as padding on each side of a fixed axis
# domain (the y axis uses it when autoscale is off).
PAD_SIZE: float = 0.05
def norm(x, vmin, vmax):
    """Linearly rescale ``x`` so that ``vmin`` maps to 0 and ``vmax`` maps to 1.

    Works element-wise on anything supporting ``-``, ``*`` and ``/`` with
    scalars (Python floats, NumPy arrays, polars Series). ``vmin`` and
    ``vmax`` are expected to be scalars. Values outside ``[vmin, vmax]`` are
    not clipped.

    If ``vmin == vmax`` the range is degenerate. Dividing by the zero span
    would raise for Python floats and give inf/NaN for arrays, so every value
    is mapped to 0 instead, as matplotlib's ``Normalize`` does.
    """
    span = vmax - vmin
    if span == 0:
        # ``x - vmin`` keeps the container type (array/Series) and propagates
        # missing values; multiplying by 0 collapses finite values to 0.
        return (x - vmin) * 0.0
    return (x - vmin) / span
class ColorTransform(BaseModel):
    """Mapping from per-residue values to colors, for both Mol* and Altair.

    Instances are built from URL query parameters in ``RoutedView``.
    """

    # Colormap name passed to ``cmap.Colormap``.
    name: str = "tol:rainbow_PuRd"
    # "categorical": raw values go straight to the colormap (presumably as
    # category indices 0..num_colors-1 -- see altair_scale). Any other value:
    # values are first rescaled with ``norm`` between vmin and vmax.
    norm_type: str = "linear"
    vmin: float = 0.0
    vmax: float = 1.0
    # Color for residues without data; also the colormap's "bad" color.
    missing_data_color: str = "#8c8c8c"
    # Color used to highlight the hovered / selected residue.
    highlight_color: str = "#e933f8"

    def molstar_colors(self, data: pl.DataFrame) -> dict:
        """Build the ``color_data`` payload for ``PDBeMolstar``.

        Rows with a null in any column are dropped first. For non-categorical
        ``norm_type`` the values are rescaled with :func:`norm` before being
        mapped through the colormap.

        Args:
            data: Frame with ``residue_number`` and ``value`` columns.

        Returns:
            ``{"data": [{"residue_number": ..., "color": "#rrggbb"}, ...],
            "nonSelectedColor": missing_data_color}``.
        """
        data = data.drop_nulls()
        if self.norm_type == "categorical":
            values = data["value"]
        else:
            values = norm(data["value"], vmin=self.vmin, vmax=self.vmax)

        # One RGBA byte row per residue. reshape(-1, 4) keeps a 2-D shape even
        # for a single residue (the previous ``squeeze`` produced a 0-d array
        # that ``zip`` cannot iterate). Formatting the bytes directly also
        # avoids the byte-order dependence of viewing them as uint32.
        rgba = np.asarray(self.cmap(values, bytes=True)).astype(np.uint8).reshape(-1, 4)
        hex_colors = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b, _alpha in rgba.tolist()]

        return {
            "data": [
                {"residue_number": resi, "color": hcolor}
                for resi, hcolor in zip(data["residue_number"], hex_colors)
            ],
            "nonSelectedColor": self.missing_data_color,
        }

    @property
    def cmap(self) -> Colormap:
        """The named colormap, with ``missing_data_color`` as its bad color."""
        return Colormap(self.name, bad=self.missing_data_color)

    @property
    def altair_scale(self) -> alt.Scale:
        """Altair color scale consistent with :meth:`molstar_colors`.

        Categorical: one color per colormap entry over the domain
        ``0 .. num_colors - 1``. Otherwise: the colormap's Altair colors over
        256 evenly spaced stops from ``vmin`` to ``vmax``. Out-of-domain values
        are clamped.
        """
        cmap = self.cmap
        if self.norm_type == "categorical":
            colors = cmap.to_altair(N=cmap.num_colors)
            domain = range(cmap.num_colors)
        else:
            colors = cmap.to_altair()
            domain = np.linspace(self.vmin, self.vmax, 256, endpoint=True)
        return alt.Scale(domain=list(domain), range=colors, clamp=True)
class AxisProperties(BaseModel):
    """Label, unit and scaling options for the value (y) axis."""

    # Name of the plotted quantity (axis/legend title and Mol* tooltips).
    label: str = "x"
    # Unit shown next to the label.
    unit: str = "au"
    # When False, make_chart derives a fixed y domain from the color range.
    autoscale_y: bool = True

    @property
    def title(self) -> str:
        """Axis title in the form ``"<label> (<unit>)"``."""
        return "{} ({})".format(self.label, self.unit)
def _padded_domain(lo, hi, frac: float, fallback: float = 0.5) -> tuple:
    """Return ``(lo - pad, hi + pad)`` with ``pad = (hi - lo) * frac``.

    When ``lo == hi`` the pad would be zero and the domain would have zero
    width, so ``fallback`` is used as the pad instead.
    """
    pad = (hi - lo) * frac
    if pad == 0:
        pad = fallback
    return (lo - pad, hi + pad)


def make_chart(
    data: pl.DataFrame, colors: ColorTransform, axis_properties: AxisProperties
) -> alt.LayerChart:
    """Build the layered Altair chart of per-residue values.

    Layers, bottom to top:

    1. ``scatter``: one colored circle per residue; mouse-wheel zoom along x
       (wheel events with shift held are not used for zooming).
    2. ``vline``: a vertical rule at x = ``line_position`` (a chart param),
       hidden via the ``line_opacity`` param; ScatterChart sets both.
    3. ``select_residue``: invisible points carrying the ``point`` selection,
       which picks the residue nearest to the pointer.
    4. ``rule``: a vertical rule at the residue chosen by ``point``.

    Args:
        data: Frame with ``residue_number`` and ``value`` columns.
        colors: Color mapping. When ``axis_properties.autoscale_y`` is False,
            its ``norm_type``/``vmin``/``vmax`` also fix the y domain.
        axis_properties: Axis label/unit and the y-autoscale flag.

    Returns:
        The layered chart, container-wide and 480 px high.
    """
    residues = data["residue_number"]
    x_scale = alt.Scale(domain=_padded_domain(residues.min(), residues.max(), PAD_SIZE))

    if axis_properties.autoscale_y:
        y_scale = alt.Scale()
    elif colors.norm_type == "categorical":
        # Category values run 0 .. num_colors - 1; pad by a fraction of the
        # number of categories.
        n_colors = colors.cmap.num_colors
        ypad = n_colors * PAD_SIZE
        y_scale = alt.Scale(domain=(0 - ypad, n_colors - 1 + ypad))
    else:
        y_scale = alt.Scale(domain=_padded_domain(colors.vmin, colors.vmax, PAD_SIZE))

    zoom_x = alt.selection_interval(
        bind="scales",
        encodings=["x"],
        zoom="wheel![!event.shiftKey]",
    )
    color_type = "O" if colors.norm_type == "categorical" else "Q"
    scatter = (
        alt.Chart(data)
        .mark_circle(size=200)
        .encode(
            x=alt.X("residue_number:Q", title="Residue Number", scale=x_scale),
            y=alt.Y("value:Q", title=axis_properties.title, scale=y_scale),
            color=alt.Color(
                f"value:{color_type}",
                scale=colors.altair_scale,
                title=axis_properties.title,
            ),
        )
        .add_params(zoom_x)
    )

    # Selection named "point": the residue nearest to the pointer (matched on
    # residue_number), empty by default and cleared on mouseout.
    nearest = alt.selection_point(
        name="point",
        nearest=True,
        on="pointerover",
        fields=["residue_number"],
        empty=False,
        clear="mouseout",
    )
    select_residue = (
        alt.Chart(data)
        .mark_point()
        .encode(x="residue_number:Q", opacity=alt.value(0))
        .add_params(nearest)
    )
    # Rule drawn at the selected residue.
    rule = (
        alt.Chart(data)
        .mark_rule(color=colors.highlight_color, size=2)
        .encode(x="residue_number:Q")
        .transform_filter(nearest)
    )

    # Param-driven rule: the single datum x=1.0 is multiplied by
    # line_position, so the rule's x equals line_position.
    line_position = alt.param(name="line_position", value=0.0)
    line_opacity = alt.param(name="line_opacity", value=1)
    vline = (
        alt.Chart(pl.DataFrame({"x": [1.0]}))
        .mark_rule(color=colors.highlight_color, opacity=line_opacity, size=2)
        .encode(x=alt.X("p", type="quantitative"))
        .transform_calculate(p=alt.datum.x * line_position)
        .add_params(line_position, line_opacity)
    )

    return alt.layer(scatter, vline, select_residue, rule).properties(
        width="container",
        height=480,  # TODO: autosize height?
    )
@solara.component
def ScatterChart(
    data: pl.DataFrame,
    colors: ColorTransform,
    axis_properties: AxisProperties,
    on_selections,
    line_value,
):
    """Render the per-residue chart as an Altair ``JupyterChart`` widget.

    Args:
        data: Frame with ``residue_number`` and ``value`` columns.
        colors: Color mapping passed to ``make_chart``.
        axis_properties: Axis options passed to ``make_chart``.
        on_selections: traitlets observer attached to the chart's ``point``
            selection.
        line_value: x position (residue number) of the highlight line, or
            ``None`` to hide it.
    """
    chart = solara.use_memo(
        lambda: make_chart(data, colors, axis_properties),
        dependencies=[data, colors, axis_properties],
    )

    # The highlight line is driven through chart params rather than by
    # rebuilding the chart.
    if line_value is None:
        params = {"line_position": 0.0, "line_opacity": 0}
    else:
        params = {"line_position": line_value, "line_opacity": 1}

    options = {"actions": False}
    if solara.lab.use_dark_effective():
        options["theme"] = "dark"

    view = alt.JupyterChart.element(  # type: ignore
        chart=chart,
        embed_options=options,
        _params=params,
    )

    def bind():
        # NOTE: JupyterChart creates a new `selections` object whenever its
        # `chart` trait changes, so the observer has to be re-attached for
        # every new chart (including axis_properties-only changes) and
        # detached from the old object on cleanup.
        selections = solara.get_widget(view).selections  # type: ignore
        selections.observe(on_selections, "point")
        return lambda: selections.unobserve(on_selections, "point")

    solara.use_effect(bind, [chart])
def is_numeric(val) -> bool:
    """Return True when ``val`` holds an actual number: not None and not NaN."""
    return val is not None and not np.isnan(val)
@solara.component
def ProteinView(
    title: str,
    molecule_id: str,
    data: pl.DataFrame,
    colors: ColorTransform,
    axis_properties: AxisProperties,
    dark_effective: bool,
    description: str = "",
):
    """Page with an app bar, a Mol* protein view and a linked value chart.

    Hovering a residue in Mol* draws a line at that residue in the chart.
    Hovering near a point in the chart highlights that residue in Mol*.

    Args:
        title: Text shown in the app bar.
        molecule_id: Molecule ID given to PDBeMolstar (lowercased).
        data: Frame with ``residue_number`` and ``value`` columns.
        colors: Color mapping shared by Mol* and the chart.
        axis_properties: Label/unit for the chart axis and Mol* tooltips.
        dark_effective: Whether the dark theme is active; the Mol* element is
            keyed on it, so it is re-created when the theme changes.
        description: Markdown for the "About" dialog; the About button is
            only shown when this is non-empty.
    """
    about_dialog = solara.use_reactive(False)
    fullscreen = solara.use_reactive(False)
    # residue number to highlight in the altair chart (set by Mol* hover)
    line_number = solara.use_reactive(None)
    # residue number to highlight in the protein view (set by chart hover)
    highlight_number = solara.use_reactive(None)

    color_data = {} if data.is_empty() else colors.molstar_colors(data)

    def tooltip_text(value) -> str:
        if is_numeric(value):
            return f"{axis_properties.label}: {value:.2g} {axis_properties.unit}"
        return "No data"

    tooltips = {
        "data": [
            {"residue_number": resi, "tooltip": tooltip_text(value)}
            for resi, value in zip(data["residue_number"], data["value"])
        ]
    }

    # Residue number 0 is a valid residue, so compare against None rather
    # than relying on truthiness.
    highlight = (
        {"data": [{"residue_number": int(highlight_number.value)}]}
        if highlight_number.value is not None
        else None
    )

    def on_molstar_mouseover(value):
        line_number.set(value.get("residueNumber", None))

    def on_molstar_mouseout(value):
        line_number.set(None)

    def on_chart_selection(event):
        # An empty selection has no first element -> clear the highlight.
        try:
            r = event["new"].value[0]["residue_number"]
            highlight_number.set(r)
        except (IndexError, KeyError):
            highlight_number.set(None)

    with solara.AppBar():
        solara.AppBarTitle(title)
        with solara.Tooltip("Fullscreen"):
            solara.Button(
                icon_name="mdi-fullscreen",
                icon=True,
                on_click=lambda: fullscreen.set(not fullscreen.value),
            )
        if description:
            with solara.Tooltip("About"):
                solara.Button(
                    icon_name="mdi-information-outline",
                    icon=True,
                    on_click=lambda: about_dialog.set(True),
                )
        solara.lab.ThemeToggle()

    with solara.v.Dialog(
        v_model=about_dialog.value, on_v_model=lambda _ignore: about_dialog.set(False)
    ):
        with solara.Card("About", margin=0):
            solara.Markdown(description)

    with solara.ColumnsResponsive([4, 8]):
        with solara.Card(style={"height": "550px"}):
            PDBeMolstar.element(  # type: ignore
                theme="dark" if dark_effective else "light",
                molecule_id=molecule_id.lower(),
                color_data=color_data,
                hide_water=True,
                tooltips=tooltips,
                height="525px",
                highlight=highlight,
                highlight_color=colors.highlight_color,
                on_mouseover_event=on_molstar_mouseover,
                on_mouseout_event=on_molstar_mouseout,
                hide_controls_icon=True,
                hide_expand_icon=True,
                hide_settings_icon=True,
                expanded=fullscreen.value,
            ).key(f"molstar-{dark_effective}")
        if not fullscreen.value:
            with solara.Card(style={"height": "550px"}):
                if data.is_empty():
                    solara.Text("No data")
                else:
                    ScatterChart(
                        data,
                        colors,
                        axis_properties,
                        on_chart_selection,
                        line_number.value,
                    )
@solara.component
def RoutedView():
    """Build a ProteinView from the current route's query string.

    Required query parameters: ``title``, ``molecule_id`` and ``data``.
    Optional: ``description``, plus any ColorTransform / AxisProperties field.
    A missing or invalid parameter shows a warning instead of the view.
    """
    route = solara.use_router()
    dark_effective = solara.lab.use_dark_effective()

    try:
        query_dict = dict(parse_qsl(route.search))
        colors = ColorTransform(**query_dict)  # type: ignore
        axis_properties = AxisProperties(**query_dict)  # type: ignore
        data = decode_data(query_dict["data"])
        title = query_dict["title"]
        molecule_id = query_dict["molecule_id"]
    except (KeyError, ValueError) as err:
        # KeyError: a required parameter is missing. ValueError: includes
        # pydantic's ValidationError (e.g. a non-numeric vmin) and presumably
        # malformed `data` payloads in decode_data -- TODO confirm what
        # decode_data raises.
        solara.Warning(f"Error: {err}")
        return

    ProteinView(
        title,
        molecule_id=molecule_id,
        data=data,
        colors=colors,
        axis_properties=axis_properties,
        dark_effective=dark_effective,
        description=query_dict.get("description", ""),
    )