lmoss commited on
Commit
fb968f9
·
1 Parent(s): 42ff5a0

added a dummy to not have to write pythreejs html to disk

Browse files
Files changed (1) hide show
  1. app.py +31 -11
app.py CHANGED
@@ -7,16 +7,25 @@ import requests
7
  import numpy as np
8
  import numpy.typing as npt
9
  from dcgan import DCGAN3D_G
10
- import os
11
  import pathlib
12
  pv.start_xvfb()
13
 
 
 
 
 
 
 
 
 
 
14
  STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / 'static'
15
 
16
  DOWNLOADS_PATH = (STREAMLIT_STATIC_PATH / "downloads")
17
  if not DOWNLOADS_PATH.is_dir():
18
  DOWNLOADS_PATH.mkdir()
19
 
 
20
  def download_checkpoint(url: str, path: str) -> None:
21
  resp = requests.get(url)
22
 
@@ -24,15 +33,22 @@ def download_checkpoint(url: str, path: str) -> None:
24
  f.write(resp.content)
25
 
26
 
27
- def generate_image(path: str,
 
28
  image_size: int = 64,
29
  z_dim: int = 512,
30
  n_channels: int = 1,
31
  n_features: int = 32,
32
- ngpu: int = 1,
33
- latent_size: int = 3) -> npt.ArrayLike:
34
  netG = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
35
  netG.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
 
 
 
 
 
 
 
36
  z = torch.randn(1, z_dim, latent_size, latent_size, latent_size)
37
  with torch.no_grad():
38
  X = netG(z)
@@ -68,6 +84,7 @@ def create_matplotlib_figure(img: npt.ArrayLike, midpoint: int):
68
  a.set_axis_off()
69
  return fig
70
 
 
71
  def main():
72
  st.title("Generating Porous Media with GANs")
73
 
@@ -106,34 +123,37 @@ def main():
106
 
107
  if not (DOWNLOADS_PATH / model_fname).exists():
108
  download_checkpoint(checkpoint_url, (DOWNLOADS_PATH / model_fname))
 
109
 
110
  latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)
111
- img = generate_image((DOWNLOADS_PATH / model_fname), latent_size=latent_size)
112
  slices, mesh, dist = create_uniform_mesh_marching_cubes(img)
113
 
114
  pv.set_plot_theme("document")
115
  pl = pv.Plotter(shape=(1, 1),
116
  window_size=(view_width, view_height))
117
  _ = pl.add_mesh(slices, cmap="gray")
118
- pl.export_html((DOWNLOADS_PATH / 'slices.html'))
 
 
119
 
120
  pl = pv.Plotter(shape=(1, 1),
121
  window_size=(view_width, view_height))
122
  _ = pl.add_mesh(mesh, scalars=dist)
123
- pl.export_html((DOWNLOADS_PATH / 'mesh.html'))
 
124
 
125
  st.header("2D Cross-Section of Generated Volume")
126
  fig = create_matplotlib_figure(img, img.shape[0]//2)
127
  st.pyplot(fig=fig)
128
 
129
- HtmlFile = open((DOWNLOADS_PATH / 'slices.html'), 'r', encoding='utf-8')
130
- source_code = HtmlFile.read()
131
  st.header("3D Intersections")
132
  components.html(source_code, width=view_width, height=view_height)
133
  st.markdown("_Click and drag to spin, right click to shift._")
134
 
135
- HtmlFile = open((DOWNLOADS_PATH / 'mesh.html'), 'r', encoding='utf-8')
136
- source_code = HtmlFile.read()
137
  st.header("3D Pore Space Mesh")
138
  components.html(source_code, width=view_width, height=view_height)
139
  st.markdown("_Click and drag to spin, right click to shift._")
 
7
  import numpy as np
8
  import numpy.typing as npt
9
  from dcgan import DCGAN3D_G
 
10
  import pathlib
11
  pv.start_xvfb()
12
 
13
+
14
+ class DummyWriteable(object):
15
+ def __init__(self):
16
+ self.html = None
17
+
18
+ def write(self, html):
19
+ self.html = html
20
+
21
+
22
  STREAMLIT_STATIC_PATH = pathlib.Path(st.__path__[0]) / 'static'
23
 
24
  DOWNLOADS_PATH = (STREAMLIT_STATIC_PATH / "downloads")
25
  if not DOWNLOADS_PATH.is_dir():
26
  DOWNLOADS_PATH.mkdir()
27
 
28
+
29
  def download_checkpoint(url: str, path: str) -> None:
30
  resp = requests.get(url)
31
 
 
33
  f.write(resp.content)
34
 
35
 
36
+ @st.cache(persist=True, allow_output_mutation=True)
37
+ def load_model(path: str,
38
  image_size: int = 64,
39
  z_dim: int = 512,
40
  n_channels: int = 1,
41
  n_features: int = 32,
42
+ ngpu: int = 1,) -> torch.nn.Module:
 
43
  netG = DCGAN3D_G(image_size, z_dim, n_channels, n_features, ngpu)
44
  netG.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
45
+ return netG
46
+
47
+
48
+ @st.cache()
49
+ def generate_image(netG: torch.nn.Module,
50
+ z_dim: int = 512,
51
+ latent_size: int = 3) -> npt.ArrayLike:
52
  z = torch.randn(1, z_dim, latent_size, latent_size, latent_size)
53
  with torch.no_grad():
54
  X = netG(z)
 
84
  a.set_axis_off()
85
  return fig
86
 
87
+
88
  def main():
89
  st.title("Generating Porous Media with GANs")
90
 
 
123
 
124
  if not (DOWNLOADS_PATH / model_fname).exists():
125
  download_checkpoint(checkpoint_url, (DOWNLOADS_PATH / model_fname))
126
+ netG = load_model((DOWNLOADS_PATH / model_fname))
127
 
128
  latent_size = st.slider("Latent Space Size z", min_value=1, max_value=5, step=1)
129
+ img = generate_image(netG, latent_size=latent_size)
130
  slices, mesh, dist = create_uniform_mesh_marching_cubes(img)
131
 
132
  pv.set_plot_theme("document")
133
  pl = pv.Plotter(shape=(1, 1),
134
  window_size=(view_width, view_height))
135
  _ = pl.add_mesh(slices, cmap="gray")
136
+
137
+ slices_html = DummyWriteable()
138
+ pl.export_html(slices_html)
139
 
140
  pl = pv.Plotter(shape=(1, 1),
141
  window_size=(view_width, view_height))
142
  _ = pl.add_mesh(mesh, scalars=dist)
143
+ mesh_html = DummyWriteable()
144
+ pl.export_html(mesh_html)
145
 
146
  st.header("2D Cross-Section of Generated Volume")
147
  fig = create_matplotlib_figure(img, img.shape[0]//2)
148
  st.pyplot(fig=fig)
149
 
150
+ source_code = slices_html.html
 
151
  st.header("3D Intersections")
152
  components.html(source_code, width=view_width, height=view_height)
153
  st.markdown("_Click and drag to spin, right click to shift._")
154
 
155
+
156
+ source_code = mesh_html.html
157
  st.header("3D Pore Space Mesh")
158
  components.html(source_code, width=view_width, height=view_height)
159
  st.markdown("_Click and drag to spin, right click to shift._")