|
from __future__ import annotations |
|
|
|
import base64 |
|
import gzip |
|
import json |
|
from dataclasses import dataclass, fields |
|
from io import BytesIO |
|
from pathlib import Path |
|
from urllib.parse import parse_qsl |
|
|
|
import altair as alt |
|
import ipywidgets as widgets |
|
import numpy as np |
|
import polars as pl |
|
import solara |
|
import solara.lab |
|
from cmap import Colormap |
|
from ipymolstar.widget import PDBeMolstar |
|
from pydantic import BaseModel |
|
|
|
from make_link import decode_data |
|
|
|
# Element-wise np.base_repr: converts every integer in an array to its string
# representation in the requested base.
base_v = np.vectorize(np.base_repr)
# Presumably the relative padding (5% of the data span) for chart axis
# domains; make_chart pads its axes by the same 5%.
PAD_SIZE = 0.05
|
|
|
|
|
def norm(x, vmin, vmax):
    """Linearly rescale *x* so that *vmin* maps to 0 and *vmax* maps to 1.

    Works element-wise on anything that supports arithmetic operators
    (scalars, numpy arrays, polars Series). Values outside ``[vmin, vmax]``
    land outside ``[0, 1]``; no clipping is performed.
    """
    span = vmax - vmin
    offset = x - vmin
    return offset / span
|
|
|
|
|
class ColorTransform(BaseModel):
    """Colormap settings shared by the Mol* structure view and the Altair chart.

    All fields are plain pydantic fields, so an instance can be built directly
    from URL query parameters (``ColorTransform(**query_dict)``).
    """

    name: str = "tol:rainbow_PuRd"  # colormap name passed to cmap.Colormap
    norm_type: str = "linear"  # "categorical" uses raw values; otherwise scaled by vmin/vmax
    vmin: float = 0.0
    vmax: float = 1.0
    missing_data_color: str = "#8c8c8c"
    highlight_color: str = "#e933f8"

    def molstar_colors(self, data: pl.DataFrame) -> dict:
        """Build a PDBeMolstar ``color_data`` payload from per-residue values.

        Args:
            data: frame with ``residue_number`` and ``value`` columns. Rows
                containing a null in any column are skipped, so those residues
                fall back to ``nonSelectedColor``.

        Returns:
            ``{"data": [{"residue_number": ..., "color": "#rrggbb"}, ...],
            "nonSelectedColor": missing_data_color}``
        """
        data = data.drop_nulls()
        color_data = {"data": [], "nonSelectedColor": self.missing_data_color}
        if data.is_empty():
            # Nothing to color. Returning here also avoids calling the
            # colormap on an empty array.
            return color_data

        if self.norm_type == "categorical":
            values = data["value"]
        else:
            values = norm(data["value"], vmin=self.vmin, vmax=self.vmax)

        # (N, 4) uint8 RGBA. reshape(-1, 4) keeps the array 2-D even for a
        # single row; the previous .squeeze() produced a 0-d array there,
        # which zip() cannot iterate.
        rgba = np.asarray(self.cmap(values, bytes=True), dtype=np.uint8).reshape(-1, 4)
        # Hex-encode the R, G, B bytes directly. Unlike reinterpreting the
        # RGBA bytes as uint32, this does not depend on platform byte order.
        hex_colors = ["#" + pixel[:3].tobytes().hex() for pixel in rgba]

        color_data["data"] = [
            {"residue_number": resi, "color": hcolor}
            for resi, hcolor in zip(data["residue_number"], hex_colors)
        ]
        return color_data

    @property
    def cmap(self) -> Colormap:
        """The ``name`` colormap, using ``missing_data_color`` as its "bad" color."""
        return Colormap(self.name, bad=self.missing_data_color)

    @property
    def altair_scale(self) -> alt.Scale:
        """Altair color scale matching this transform.

        Categorical: one color per category index ``0 .. num_colors - 1``.
        Otherwise: colors spread over 256 evenly spaced stops in
        ``[vmin, vmax]``, clamped so out-of-range values take the end colors.
        """
        if self.norm_type == "categorical":
            colors = self.cmap.to_altair(N=self.cmap.num_colors)
            domain = range(self.cmap.num_colors)
        else:
            colors = self.cmap.to_altair()
            domain = np.linspace(self.vmin, self.vmax, 256, endpoint=True)

        scale = alt.Scale(domain=list(domain), range=colors, clamp=True)
        return scale
|
|
|
|
|
class AxisProperties(BaseModel):
    """Label and unit of the plotted quantity, plus the y-axis scaling mode."""

    label: str = "x"  # name of the plotted quantity
    unit: str = "au"  # unit shown in parentheses after the label
    autoscale_y: bool = True  # when False, make_chart fixes the y-domain from the colormap range

    @property
    def title(self) -> str:
        """Axis title of the form ``"<label> (<unit>)"``."""
        return "{} ({})".format(self.label, self.unit)
|
|
|
|
|
def make_chart(
    data: pl.DataFrame, colors: ColorTransform, axis_properties: AxisProperties
) -> alt.LayerChart:
    """Build the layered per-residue Altair chart.

    Layers, bottom to top:
      * ``scatter``: one circle per residue, colored via ``colors``, with
        mouse-wheel zoom on the x axis (not while shift is held).
      * ``vline``: a vertical rule at x = ``line_position``, shown or hidden
        through the ``line_position`` / ``line_opacity`` params, which are
        set from outside the chart.
      * ``select_residue``: invisible points that carry the ``point``
        selection, which follows the pointer to the nearest residue.
      * ``rule``: a rule drawn at the currently selected residue.
    """
    xmin, xmax = data["residue_number"].min(), data["residue_number"].max()
    # Pad the x-domain so points at the edges are not clipped. If all points
    # share one residue number, the span is 0. In that case fall back to a
    # one-residue pad so the domain is not degenerate.
    xpad = (xmax - xmin) * PAD_SIZE or 1
    xscale = alt.Scale(domain=(xmin - xpad, xmax + xpad))

    if axis_properties.autoscale_y:
        y_scale = alt.Scale()
    elif colors.norm_type == "categorical":
        # Fixed domain covering category indices 0 .. num_colors - 1, padded.
        num_colors = colors.cmap.num_colors
        ypad = num_colors * PAD_SIZE
        y_scale = alt.Scale(domain=(0 - ypad, num_colors - 1 + ypad))
    else:
        # Fixed domain matching the color normalisation range, padded.
        ypad = (colors.vmax - colors.vmin) * PAD_SIZE
        y_scale = alt.Scale(domain=(colors.vmin - ypad, colors.vmax + ypad))

    # Wheel zoom on x only; wheel events with shift held are filtered out.
    zoom_x = alt.selection_interval(
        bind="scales",
        encodings=["x"],
        zoom="wheel![!event.shiftKey]",
    )

    scatter = (
        alt.Chart(data)
        .mark_circle(interpolate="basis", size=200)
        .encode(
            x=alt.X("residue_number:Q", title="Residue Number", scale=xscale),
            y=alt.Y(
                "value:Q",
                title=axis_properties.title,
                scale=y_scale,
            ),
            color=alt.Color(
                f"value:{'O' if colors.norm_type == 'categorical' else 'Q'}",
                scale=colors.altair_scale,
                title=axis_properties.title,
            ),
        )
        .add_params(zoom_x)
    )

    # Point selection on the residue nearest to the pointer; cleared on mouseout.
    nearest = alt.selection_point(
        name="point",
        nearest=True,
        on="pointerover",
        fields=["residue_number"],
        empty=False,
        clear="mouseout",
    )

    select_residue = (
        alt.Chart(data)
        .mark_point()
        .encode(
            x="residue_number:Q",
            opacity=alt.value(0),
        )
        .add_params(nearest)
    )

    rule = (
        alt.Chart(data)
        .mark_rule(color=colors.highlight_color, size=2)
        .encode(
            x="residue_number:Q",
        )
        .transform_filter(nearest)
    )

    # Externally driven vertical line. A one-row frame with x = 1.0, scaled by
    # the line_position param, places the rule at x = line_position.
    line_position = alt.param(name="line_position", value=0.0)
    line_opacity = alt.param(name="line_opacity", value=1)
    df_line = pl.DataFrame({"x": [1.0]})

    vline = (
        alt.Chart(df_line)
        .mark_rule(color=colors.highlight_color, opacity=line_opacity, size=2)
        .encode(x=alt.X("p", type="quantitative"))
        .transform_calculate(p=alt.datum.x * line_position)
        .add_params(line_position, line_opacity)
    )

    chart = alt.layer(scatter, vline, select_residue, rule).properties(
        width="container",
        height=480,
    )

    return chart
|
|
|
|
|
@solara.component
def ScatterChart(
    data: pl.DataFrame,
    colors: ColorTransform,
    axis_properties: AxisProperties,
    on_selections,
    line_value,
):
    """Render the residue chart and forward changes to its ``point`` selection.

    Args:
        data, colors, axis_properties: forwarded to ``make_chart``.
        on_selections: traitlets observer attached to the chart widget's
            ``selections`` trait ``"point"``.
        line_value: x position of the vertical highlight rule, or None to
            hide it.
    """
    chart = solara.use_memo(
        lambda: make_chart(data, colors, axis_properties),
        dependencies=[data, colors, axis_properties],
    )

    if line_value is not None:
        params = {"line_position": line_value, "line_opacity": 1}
    else:
        # Keep the rule in the spec but make it invisible.
        params = {"line_position": 0.0, "line_opacity": 0}

    dark_effective = solara.lab.use_dark_effective()
    options = {"actions": False}
    if dark_effective:
        options["theme"] = "dark"

    view = alt.JupyterChart.element(
        chart=chart,
        embed_options=options,
        _params=params,
    )

    def bind():
        real = solara.get_widget(view)
        # Altair's JupyterChart re-creates its `selections` object when its
        # chart changes. Hold on to the current one so cleanup detaches the
        # observer from exactly the object it was attached to.
        selections = real.selections
        selections.observe(on_selections, "point")

        def unbind():
            selections.unobserve(on_selections, "point")

        return unbind

    # Re-bind whenever the memoized chart changes, which happens when data,
    # colors or axis_properties change. The cleanup above prevents duplicate
    # observers from piling up across re-runs.
    solara.use_effect(bind, [chart])
|
|
|
|
|
def is_numeric(val) -> bool:
    """Return True when *val* holds an actual number, i.e. is neither None nor NaN."""
    return val is not None and not np.isnan(val)
|
|
|
|
|
@solara.component
def ProteinView(
    title: str,
    molecule_id: str,
    data: pl.DataFrame,
    colors: ColorTransform,
    axis_properties: AxisProperties,
    dark_effective: bool,
    description: str = "",
):
    """Page layout: app bar, Mol* structure view and the linked residue chart.

    Hovering a residue in Mol* moves the chart's vertical rule to it.
    Selecting a point in the chart highlights that residue in Mol*.
    """
    about_dialog = solara.use_reactive(False)
    fullscreen = solara.use_reactive(False)
    # Residue number hovered in Mol*; drives the chart's vertical rule.
    line_number = solara.use_reactive(None)
    # Residue number selected in the chart; highlighted in Mol*.
    highlight_number = solara.use_reactive(None)

    color_data = {} if data.is_empty() else colors.molstar_colors(data)

    tooltips = {
        "data": [
            {
                "residue_number": resi,
                "tooltip": (
                    f"{axis_properties.label}: {value:.2g} {axis_properties.unit}"
                    if is_numeric(value)
                    else "No data"
                ),
            }
            for resi, value in zip(data["residue_number"], data["value"])
        ]
    }

    def on_molstar_mouseover(value):
        line_number.set(value.get("residueNumber", None))

    def on_molstar_mouseout(value):
        # Clear the rule through the same path as a hover without a residue.
        on_molstar_mouseover({})

    def on_chart_selection(event):
        try:
            r = event["new"].value[0]["residue_number"]
        except (IndexError, KeyError):
            # Empty selection (or no residue field): clear the highlight.
            highlight_number.set(None)
        else:
            highlight_number.set(r)

    # Test for None rather than truthiness so that a residue numbered 0 can
    # still be highlighted.
    highlight = (
        {"data": [{"residue_number": int(highlight_number.value)}]}
        if highlight_number.value is not None
        else None
    )

    with solara.AppBar():
        solara.AppBarTitle(title)
        with solara.Tooltip("Fullscreen"):
            solara.Button(
                icon_name="mdi-fullscreen",
                icon=True,
                on_click=lambda: fullscreen.set(not fullscreen.value),
            )
        if description:
            with solara.Tooltip("About"):
                solara.Button(
                    icon_name="mdi-information-outline",
                    icon=True,
                    on_click=lambda: about_dialog.set(True),
                )
        solara.lab.ThemeToggle()

    with solara.v.Dialog(
        v_model=about_dialog.value, on_v_model=lambda _ignore: about_dialog.set(False)
    ):
        with solara.Card("About", margin=0):
            solara.Markdown(description)

    with solara.ColumnsResponsive([4, 8]):
        with solara.Card(style={"height": "550px"}):
            PDBeMolstar.element(
                theme="dark" if dark_effective else "light",
                molecule_id=molecule_id.lower(),
                color_data=color_data,
                hide_water=True,
                tooltips=tooltips,
                height="525px",
                highlight=highlight,
                highlight_color=colors.highlight_color,
                on_mouseover_event=on_molstar_mouseover,
                on_mouseout_event=on_molstar_mouseout,
                hide_controls_icon=True,
                hide_expand_icon=True,
                hide_settings_icon=True,
                expanded=fullscreen.value,
            ).key(f"molstar-{dark_effective}")  # key changes with the theme, forcing a new widget
        if not fullscreen.value:
            with solara.Card(style={"height": "550px"}):
                if data.is_empty():
                    solara.Text("No data")
                else:
                    ScatterChart(
                        data,
                        colors,
                        axis_properties,
                        on_chart_selection,
                        line_number.value,
                    )
|
|
|
|
|
@solara.component
def RoutedView():
    """Build a ProteinView from the current route's query string.

    Required query keys are ``title``, ``molecule_id`` and ``data`` (decoded
    by ``decode_data``). The same query keys are also passed as keyword
    arguments to ``ColorTransform`` and ``AxisProperties``. Missing or
    invalid input is shown as a warning instead of crashing the page.
    """
    route = solara.use_router()
    dark_effective = solara.lab.use_dark_effective()

    try:
        query_dict = dict(parse_qsl(route.search))
        colors = ColorTransform(**query_dict)
        axis_properties = AxisProperties(**query_dict)
        data = decode_data(query_dict["data"])
        title = query_dict["title"]
        molecule_id = query_dict["molecule_id"]
    except (KeyError, ValueError) as err:
        # KeyError: a required query key is missing.
        # ValueError: covers pydantic validation errors (e.g. a non-numeric
        # vmin). It also covers decode_data failures, if those raise
        # ValueError; decode_data is not visible here, so confirm.
        solara.Warning(f"Error: {err}")
        return

    ProteinView(
        title,
        molecule_id=molecule_id,
        data=data,
        colors=colors,
        axis_properties=axis_properties,
        dark_effective=dark_effective,
        description=query_dict.get("description", ""),
    )
|
|