Spaces:
Runtime error
Runtime error
lmoss
commited on
Commit
·
fb968f9
1
Parent(s):
42ff5a0
added a dummy to not have to write pythreejs html to disk
Browse files
app.py
CHANGED
@@ -7,16 +7,25 @@ import requests
|
|
7 |
import numpy as np
|
8 |
import numpy.typing as npt
|
9 |
from dcgan import DCGAN3D_G
|
10 |
-
import os
|
11 |
import pathlib
|
12 |
pv.start_xvfb()
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / 'static'
|
15 |
|
16 |
DOWNLOADS_PATH = (STREAMLIT_STATIC_PATH / "downloads")
|
17 |
if not DOWNLOADS_PATH.is_dir():
|
18 |
DOWNLOADS_PATH.mkdir()
|
19 |
|
|
|
20 |
def download_checkpoint(url: str, path: str) -> None:
|
21 |
resp = requests.get(url)
|
22 |
|
@@ -24,15 +33,22 @@ def download_checkpoint(url: str, path: str) -> None:
|
|
24 |
f.write(resp.content)
|
25 |
|
26 |
|
27 |
-
|
|
|
28 |
image_size: int = 64,
|
29 |
z_dim: int = 512,
|
30 |
n_channels: int = 1,
|
31 |
n_features: int = 32,
|
32 |
-
ngpu: int = 1,
|
33 |
-
latent_size: int = 3) -> npt.ArrayLike:
|
34 |
netG = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
|
35 |
netG.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
z = torch.randn(1, z_dim, latent_size, latent_size, latent_size)
|
37 |
with torch.no_grad():
|
38 |
X = netG(z)
|
@@ -68,6 +84,7 @@ def create_matplotlib_figure(img: npt.ArrayLike, midpoint: int):
|
|
68 |
a.set_axis_off()
|
69 |
return fig
|
70 |
|
|
|
71 |
def main():
|
72 |
st.title("Generating Porous Media with GANs")
|
73 |
|
@@ -106,34 +123,37 @@ def main():
|
|
106 |
|
107 |
if not (DOWNLOADS_PATH / model_fname).exists():
|
108 |
download_checkpoint(checkpoint_url, (DOWNLOADS_PATH / model_fname))
|
|
|
109 |
|
110 |
latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)
|
111 |
-
img = generate_image(
|
112 |
slices, mesh, dist = create_uniform_mesh_marching_cubes(img)
|
113 |
|
114 |
pv.set_plot_theme("document")
|
115 |
pl = pv.Plotter(shape=(1, 1),
|
116 |
window_size=(view_width, view_height))
|
117 |
_ = pl.add_mesh(slices, cmap="gray")
|
118 |
-
|
|
|
|
|
119 |
|
120 |
pl = pv.Plotter(shape=(1, 1),
|
121 |
window_size=(view_width, view_height))
|
122 |
_ = pl.add_mesh(mesh, scalars=dist)
|
123 |
-
|
|
|
124 |
|
125 |
st.header("2D Cross-Section of Generated Volume")
|
126 |
fig = create_matplotlib_figure(img, img.shape[0]//2)
|
127 |
st.pyplot(fig=fig)
|
128 |
|
129 |
-
|
130 |
-
source_code = HtmlFile.read()
|
131 |
st.header("3D Intersections")
|
132 |
components.html(source_code, width=view_width, height=view_height)
|
133 |
st.markdown("_Click and drag to spin, right click to shift._")
|
134 |
|
135 |
-
|
136 |
-
source_code =
|
137 |
st.header("3D Pore Space Mesh")
|
138 |
components.html(source_code, width=view_width, height=view_height)
|
139 |
st.markdown("_Click and drag to spin, right click to shift._")
|
|
|
7 |
import numpy as np
|
8 |
import numpy.typing as npt
|
9 |
from dcgan import DCGAN3D_G
|
|
|
10 |
import pathlib
|
11 |
pv.start_xvfb()
|
12 |
|
13 |
+
|
14 |
+
class DummyWriteable(object):
|
15 |
+
def __init__(self):
|
16 |
+
self.html = None
|
17 |
+
|
18 |
+
def write(self, html):
|
19 |
+
self.html = html
|
20 |
+
|
21 |
+
|
22 |
STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / 'static'
|
23 |
|
24 |
DOWNLOADS_PATH = (STREAMLIT_STATIC_PATH / "downloads")
|
25 |
if not DOWNLOADS_PATH.is_dir():
|
26 |
DOWNLOADS_PATH.mkdir()
|
27 |
|
28 |
+
|
29 |
def download_checkpoint(url: str, path: str) -> None:
|
30 |
resp = requests.get(url)
|
31 |
|
|
|
33 |
f.write(resp.content)
|
34 |
|
35 |
|
36 |
+
@st.cache(persist=True, allow_output_mutation=True)
|
37 |
+
def load_model(path: str,
|
38 |
image_size: int = 64,
|
39 |
z_dim: int = 512,
|
40 |
n_channels: int = 1,
|
41 |
n_features: int = 32,
|
42 |
+
ngpu: int = 1,) -> torch.nn.Module:
|
|
|
43 |
netG = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
|
44 |
netG.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
|
45 |
+
return netG
|
46 |
+
|
47 |
+
|
48 |
+
@st.cache()
|
49 |
+
def generate_image(netG: torch.nn.Module,
|
50 |
+
z_dim: int = 512,
|
51 |
+
latent_size: int = 3) -> npt.ArrayLike:
|
52 |
z = torch.randn(1, z_dim, latent_size, latent_size, latent_size)
|
53 |
with torch.no_grad():
|
54 |
X = netG(z)
|
|
|
84 |
a.set_axis_off()
|
85 |
return fig
|
86 |
|
87 |
+
|
88 |
def main():
|
89 |
st.title("Generating Porous Media with GANs")
|
90 |
|
|
|
123 |
|
124 |
if not (DOWNLOADS_PATH / model_fname).exists():
|
125 |
download_checkpoint(checkpoint_url, (DOWNLOADS_PATH / model_fname))
|
126 |
+
netG = load_model((DOWNLOADS_PATH / model_fname))
|
127 |
|
128 |
latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)
|
129 |
+
img = generate_image(netG, latent_size=latent_size)
|
130 |
slices, mesh, dist = create_uniform_mesh_marching_cubes(img)
|
131 |
|
132 |
pv.set_plot_theme("document")
|
133 |
pl = pv.Plotter(shape=(1, 1),
|
134 |
window_size=(view_width, view_height))
|
135 |
_ = pl.add_mesh(slices, cmap="gray")
|
136 |
+
|
137 |
+
slices_html = DummyWriteable()
|
138 |
+
pl.export_html(slices_html)
|
139 |
|
140 |
pl = pv.Plotter(shape=(1, 1),
|
141 |
window_size=(view_width, view_height))
|
142 |
_ = pl.add_mesh(mesh, scalars=dist)
|
143 |
+
mesh_html = DummyWriteable()
|
144 |
+
pl.export_html(mesh_html)
|
145 |
|
146 |
st.header("2D Cross-Section of Generated Volume")
|
147 |
fig = create_matplotlib_figure(img, img.shape[0]//2)
|
148 |
st.pyplot(fig=fig)
|
149 |
|
150 |
+
source_code = slices_html.html
|
|
|
151 |
st.header("3D Intersections")
|
152 |
components.html(source_code, width=view_width, height=view_height)
|
153 |
st.markdown("_Click and drag to spin, right click to shift._")
|
154 |
|
155 |
+
|
156 |
+
source_code = mesh_html.html
|
157 |
st.header("3D Pore Space Mesh")
|
158 |
components.html(source_code, width=view_width, height=view_height)
|
159 |
st.markdown("_Click and drag to spin, right click to shift._")
|