import torch
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from PIL import Image
from matplotlib import rc, patches, colors

# Global matplotlib style: serif fonts rendered through LaTeX (usetex needs a
# working LaTeX installation), no image interpolation, amsmath/amssymb loaded.
rc("font", **{"family": "serif", "serif": ["Roman"]})
rc("text", usetex=True)
rc("image", interpolation="none")
rc("text.latex", preamble=r"\usepackage{amsmath} \usepackage{amssymb}")

# Project-local import; kept after the rc(...) calls to preserve import order.
from datasets import get_attr_max_min  # noqa: E402


def _load_icon(path, size=(35, 35)):
    """Load the image at ``path``, resize it to ``size`` and divide it by 255.

    The file is opened in a ``with`` block so its handle is released right
    away rather than whenever the ``Image`` object is garbage collected.
    """
    with Image.open(path) as img:
        return np.array(img.resize(size)) / 255


# Icon overlaid next to intervened nodes; the path is relative to the CWD.
HAMMER = _load_icon("./hammer.png")

# Label font size shared by every causal-graph figure.
_FONT_SIZE = 30


class MidpointNormalize(colors.Normalize):
    """Piecewise-linear normalisation that sends ``midpoint`` to 0.5.

    The outer ends are symmetric: ``-v_ext`` maps to 0 and ``+v_ext`` maps to
    1, where ``v_ext = max(|vmin|, |vmax|)``. ``midpoint`` must lie in
    ``[-v_ext, v_ext]`` because ``np.interp`` needs increasing sample points.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Map ``value`` into [0, 1]; ``clip`` is accepted but ignored.

        Unset ``vmin``/``vmax`` are autoscaled from ``value`` (as the base
        class does), and an unset ``midpoint`` is treated as 0.
        """
        if self.vmin is None or self.vmax is None:
            self.autoscale_None(value)
        midpoint = 0.0 if self.midpoint is None else self.midpoint
        v_ext = np.max([np.abs(self.vmin), np.abs(self.vmax)])
        x, y = [-v_ext, midpoint, v_ext], [0, 0.5, 1]
        # np.interp clamps inputs outside [-v_ext, v_ext] to 0 / 1.
        return np.ma.masked_array(np.interp(value, x, y))


def postprocess(x):
    """Map a tensor from [-1, 1] to [0, 255] and return it as a NumPy array.

    The result is squeezed, detached from autograd and moved to the CPU.
    Values are not clipped.
    """
    return ((x + 1.0) * 127.5).squeeze().detach().cpu().numpy()


def _draw_causal_graph(
    G, pos, shaded, interventions, *, figsize, margins, x_lim, y_lim, rect_xy, rect_size
):
    """Draw ``G`` on a fresh figure and return ``(fig, ax)``.

    All open pyplot figures are closed first.

    Args:
        G: the ``nx.DiGraph`` to draw; node names double as labels.
        pos: mapping node -> (x, y) layout position.
        shaded: nodes filled light grey; all other nodes are white.
        interventions: iterable of ``(active, node, incoming_edges)``. For
            each truthy ``active`` the node gets a red outline. The listed
            edges are drawn light grey rather than removed, so the layout
            is identical with or without interventions.
        figsize: figure size in inches.
        margins: ``(x, y)`` axis margins.
        x_lim, y_lim: fixed axis limits, so every variant of a graph lines up.
        rect_xy, rect_size: lower-left corner and ``(width, height)`` of the
            outlined rectangle (callers place it around the z / epsilon nodes).

    Raises:
        KeyError: if an intervention lists an edge that is not in ``G``.
    """
    node_line_c = {node: "black" for node in G.nodes}
    edge_c = {edge: "black" for edge in G.edges}
    for active, node, incoming in interventions:
        if not active:
            continue
        node_line_c[node] = "red"
        for edge in incoming:
            if edge not in edge_c:
                raise KeyError(f"{edge!r} is not an edge of the graph")
            edge_c[edge] = "lightgrey"

    # Colour lists are built in G's own node/edge order, which is the order
    # draw_networkx iterates by default, so every colour lines up with its item.
    options = {
        "font_size": _FONT_SIZE,
        "node_size": 3000,
        "node_color": ["lightgrey" if node in shaded else "white" for node in G.nodes],
        "edgecolors": [node_line_c[node] for node in G.nodes],
        "edge_color": [edge_c[edge] for edge in G.edges],
        "linewidths": 2,
        "width": 2,
    }
    plt.close("all")
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.margins(x=margins[0], y=margins[1], tight=False)
    ax.axis("off")
    nx.draw_networkx(G, pos, **options, arrowsize=25, arrowstyle="-|>", ax=ax)
    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)
    ax.add_patch(
        patches.FancyBboxPatch(
            rect_xy,
            rect_size[0],
            rect_size[1],
            boxstyle="round, pad=0.05, rounding_size=0",
            linewidth=2,
            edgecolor="black",
            facecolor="none",
            linestyle="-",
        )
    )
    return fig, ax


def _add_hammers(fig, flags, offsets):
    """Overlay ``HAMMER`` on ``fig`` for every truthy entry of ``flags``.

    ``offsets`` holds ``(fx, fy)`` fractions of the figure's pixel width and
    height at which each icon is placed. The i-th icon gets ``zorder=10 + i``.
    """
    for idx, (flag, (fx, fy)) in enumerate(zip(flags, offsets)):
        if flag:
            fig.figimage(
                HAMMER, fx * fig.bbox.xmax, fy * fig.bbox.ymax, zorder=10 + idx
            )


def _fig_to_rgba(fig, tight_layout=False):
    """Render ``fig`` and return a copy of its RGBA pixel buffer."""
    if tight_layout:
        fig.tight_layout()
    fig.canvas.draw()
    return np.array(fig.canvas.renderer.buffer_rgba())


def mnist_graph(*args):
    """Render the MNIST causal graph (nodes x, t, i, y) as an RGBA array.

    ``args[0]``, ``args[1]`` and ``args[2]`` are truthy flags for do(t), do(i)
    and do(y). An intervened node gets a red outline, a hammer icon and
    greyed-out incoming edges. The figure is not closed here; the next graph
    call closes all figures.
    """
    do_t, do_i, do_y = args[0], args[1], args[2]
    x, t, i, y = r"$\mathbf{x}$", r"$t$", r"$i$", r"$y$"
    ut, ui, uy = r"$\mathbf{U}_t$", r"$\mathbf{U}_i$", r"$\mathbf{U}_y$"
    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    G = nx.DiGraph()
    G.add_edges_from(
        [(t, x), (i, x), (y, x), (t, i), (ut, t), (ui, i), (uy, y), (zx, x), (ex, x)]
    )
    pos = {
        y: (0, 0),
        uy: (-1, 0),
        t: (0, 0.5),
        ut: (0, 1),
        x: (1, 0),
        zx: (2, 0.375),
        ex: (2, 0),
        i: (1, 0.5),
        ui: (1, 1),
    }
    fig, ax = _draw_causal_graph(
        G,
        pos,
        shaded={x, t, i, y},
        interventions=[
            (do_t, t, [(ut, t)]),
            (do_i, i, [(ui, i), (t, i)]),
            (do_y, y, [(uy, y)]),
        ],
        figsize=(6, 4.1),
        margins=(0.06, 0.15),
        x_lim=(-1.348, 2.348),
        y_lim=(-0.215, 1.215),
        rect_xy=(1.75, -0.16),
        rect_size=(0.5, 0.7),
    )
    ax.text(1.85, 0.65, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=_FONT_SIZE)
    _add_hammers(
        fig,
        [do_t, do_i, do_y],
        [(0.26, 0.525), (0.5175, 0.525), (0.26, 0.2)],
    )
    return _fig_to_rgba(fig, tight_layout=True)


def brain_graph(*args):
    """Render the brain-MRI causal graph (nodes x, m, s, a, b, v) as an RGBA array.

    ``args[0]`` to ``args[4]`` are truthy flags for do(m), do(s), do(a), do(b)
    and do(v). An intervened node gets a red outline, a hammer icon and
    greyed-out incoming edges. The figure is not closed here; the next graph
    call closes all figures.
    """
    do_m, do_s, do_a, do_b, do_v = args[0], args[1], args[2], args[3], args[4]
    x, m, s, a, b, v = r"$\mathbf{x}$", r"$m$", r"$s$", r"$a$", r"$b$", r"$v$"
    um, us, ua, ub, uv = (
        r"$\mathbf{U}_m$",
        r"$\mathbf{U}_s$",
        r"$\mathbf{U}_a$",
        r"$\mathbf{U}_b$",
        r"$\mathbf{U}_v$",
    )
    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    G = nx.DiGraph()
    G.add_edges_from(
        [
            (m, x),
            (s, x),
            (b, x),
            (v, x),
            (zx, x),
            (ex, x),
            (a, b),
            (a, v),
            (s, b),
            (um, m),
            (us, s),
            (ua, a),
            (ub, b),
            (uv, v),
        ]
    )
    pos = {
        x: (0, 0),
        zx: (-0.25, -1),
        ex: (0.25, -1),
        a: (0, 1),
        ua: (0, 2),
        s: (1, 0),
        us: (1, -1),
        b: (1, 1),
        ub: (1, 2),
        m: (-1, 0),
        um: (-1, -1),
        v: (-1, 1),
        uv: (-1, 2),
    }
    fig, ax = _draw_causal_graph(
        G,
        pos,
        shaded={x, m, s, a, b, v},
        interventions=[
            (do_m, m, [(um, m)]),
            (do_s, s, [(us, s)]),
            (do_a, a, [(ua, a)]),
            (do_b, b, [(ub, b), (s, b), (a, b)]),
            # b -> v is not an edge of G (it is the hand-drawn arrow below),
            # so only the edges that actually exist are greyed out here.
            (do_v, v, [(uv, v), (a, v)]),
        ],
        figsize=(5, 5),
        margins=(0.1, 0.08),
        x_lim=(-1.32, 1.32),
        y_lim=(-1.414, 2.414),
        rect_xy=(-0.5, -1.325),
        rect_size=(1, 0.65),
    )
    _add_hammers(
        fig,
        [do_m, do_s, do_a, do_b, do_v],
        [(0.0075, 0.395), (0.72, 0.395), (0.363, 0.64), (0.72, 0.64), (0.0075, 0.64)],
    )
    if not do_v:
        # b, a and v all sit at y=1 in ``pos``, so a straight b -> v edge would
        # run through a. It is drawn as a curved arrow instead and is simply
        # omitted under do(v).
        ax.add_patch(
            patches.FancyArrowPatch(
                (0.86, 1.21),
                (-0.86, 1.21),
                connectionstyle="arc3,rad=.3",
                linewidth=2,
                arrowstyle="simple, head_width=10, head_length=10",
                color="k",
            )
        )
    return _fig_to_rgba(fig, tight_layout=True)


def chest_graph(*args):
    """Render the chest causal graph (nodes x, a, d, r, s) as an RGBA array.

    ``args[0]`` to ``args[3]`` are truthy flags for do(r), do(s), do(d) (also
    referred to as do_f) and do(a). An intervened node gets a red outline, a
    hammer icon and greyed-out incoming edges. The figure is not closed here;
    the next graph call closes all figures.
    """
    do_r, do_s, do_d, do_a = args[0], args[1], args[2], args[3]
    x, a, d, r, s = r"$\mathbf{x}$", r"$a$", r"$d$", r"$r$", r"$s$"
    ua, ud, ur, us = (
        r"$\mathbf{U}_a$",
        r"$\mathbf{U}_d$",
        r"$\mathbf{U}_r$",
        r"$\mathbf{U}_s$",
    )
    zx, ex = r"$\mathbf{z}_{1:L}$", r"$\boldsymbol{\epsilon}$"

    G = nx.DiGraph()
    G.add_edges_from(
        [
            (ua, a),
            (ud, d),
            (ur, r),
            (us, s),
            (a, d),
            (d, x),
            (r, x),
            (s, x),
            (ex, x),
            (zx, x),
            (a, x),
        ]
    )
    pos = {
        x: (0, 0),
        a: (-1, 1),
        d: (0, 1),
        r: (1, 1),
        s: (1, 0),
        ua: (-1, 2),
        ud: (0, 2),
        ur: (1, 2),
        us: (1, -1),
        zx: (-0.25, -1),
        ex: (0.25, -1),
    }
    fig, ax = _draw_causal_graph(
        G,
        pos,
        shaded={x, a, d, r, s},
        interventions=[
            (do_r, r, [(ur, r)]),
            (do_s, s, [(us, s)]),
            (do_d, d, [(ud, d), (a, d)]),
            (do_a, a, [(ua, a)]),
        ],
        figsize=(5, 5),
        margins=(0.1, 0.08),
        x_lim=(-1.32, 1.32),
        y_lim=(-1.414, 2.414),
        rect_xy=(-0.5, -1.325),
        rect_size=(1, 0.65),
    )
    ax.text(-0.9, -1.075, r"$\mathbf{U}_{\mathbf{x}}$", fontsize=_FONT_SIZE)
    _add_hammers(
        fig,
        [do_r, do_s, do_d, do_a],
        [(0.72, 0.64), (0.72, 0.395), (0.363, 0.64), (0.0075, 0.64)],
    )
    return _fig_to_rgba(fig, tight_layout=True)


# Log-space (mean, std) pairs used to standardise the continuous UKBB parents
# for the VAE; presumably statistics of its training data (TODO confirm).
_UKBB_LOG_STATS = {
    "age": (4.112339973449707, 0.11769197136163712),
    "brain_volume": (13.965583801269531, 0.09537758678197861),
    "ventricle_volume": (10.345998764038086, 0.43127763271331787),
}


def vae_preprocess(args, pa):
    """Turn a dict of parent tensors into the spatial conditioning tensor for the VAE.

    UKBB path (``"ukbb" in args.hps``): the PGM was trained on [-1, 1]-normalised
    parents, while the VAE was trained on log-standardised ones. Every parent
    except ``mri_seq`` and ``sex`` is therefore first mapped back to its
    original range via ``get_attr_max_min``. Then age, brain_volume and
    ventricle_volume are log-standardised with ``_UKBB_LOG_STATS``.

    Conditioning tensor: the parents named in ``args.parents_x`` are
    concatenated along dim 1 (1-D tensors get a trailing axis), tiled to
    ``args.input_res`` x ``args.input_res``, moved to ``args.device`` and cast
    to float. Assuming each parent is (B,) or (B, k), the result has shape
    (B, C, input_res, input_res).

    The caller's ``pa`` dict is not modified.
    """
    if "ukbb" in args.hps:
        pa = dict(pa)  # shallow copy: rebinding entries must not leak to the caller
        for k, v in list(pa.items()):
            if k not in ("mri_seq", "sex"):
                _max, _min = get_attr_max_min(k)
                pa[k] = (v + 1) / 2 * (_max - _min) + _min  # [-1,1] -> original range
        for k, (mean, std) in _UKBB_LOG_STATS.items():
            if k in pa:
                # clamp keeps log() finite for non-positive inputs
                pa[k] = (torch.log(pa[k].clamp(min=1e-12)) - mean) / std

    cond = torch.cat(
        [pa[k] if len(pa[k].shape) > 1 else pa[k][..., None] for k in args.parents_x],
        dim=1,
    )
    res = args.input_res
    return cond[..., None, None].repeat(1, 1, res, res).to(args.device).float()


def preprocess_brain(args, obs):
    """Move brain observations to ``args.device`` and scale them to [-1, 1].

    ``obs["x"]`` gets a leading axis and is mapped from [0, 255] to [-1, 1].
    Every other entry is cast to float and reshaped to (1, 1), so it must hold
    exactly one element. age, brain_volume and ventricle_volume are
    additionally min-max scaled to [-1, 1] using ``get_attr_max_min``.

    ``obs`` is updated in place and also returned.
    """
    obs["x"] = (obs["x"][None, ...].float().to(args.device) - 127.5) / 127.5  # [-1,1]
    for k in [key for key in obs if key != "x"]:
        obs[k] = obs[k].float().to(args.device).view(1, 1)
        if k in ("age", "brain_volume", "ventricle_volume"):
            k_max, k_min = get_attr_max_min(k)
            obs[k] = normalize(obs[k], x_min=k_min, x_max=k_max)  # [-1,1]
    return obs


def get_fig_arr(x, width=4, height=4, dpi=144, cmap="Greys_r", norm=None):
    """Render image ``x`` on a borderless figure and return its RGBA pixels.

    With the default ``"Greys_r"`` colormap the display range is fixed to
    [0, 255] (the range ``postprocess`` produces) and ``norm`` is ignored.
    Any other ``cmap`` is used together with ``norm``. The figure is always
    closed before returning, so repeated calls do not pile up open figures.
    """
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    try:
        ax = fig.add_axes([0, 0, 1, 1], frameon=False)
        if cmap == "Greys_r":
            ax.imshow(x, cmap=cmap, vmin=0, vmax=255)
        else:
            ax.imshow(x, cmap=cmap, norm=norm)
        ax.axis("off")
        return _fig_to_rgba(fig)
    finally:
        plt.close(fig)


def normalize(x, x_min=None, x_max=None, zero_one=False):
    """Min-max scale ``x`` to [-1, 1], or to [0, 1] when ``zero_one`` is true.

    ``x_min`` and ``x_max`` default to ``x.min()`` and ``x.max()``. Works with
    anything that supports ``.min()``, ``.max()`` and arithmetic (e.g. NumPy
    arrays, torch tensors). Note that ``x_max == x_min`` divides by zero.
    """
    if x_min is None:
        x_min = x.min()
    if x_max is None:
        x_max = x.max()
    x = (x - x_min) / (x_max - x_min)  # [0,1]
    return x if zero_one else 2 * x - 1  # else [-1,1]