JMalott commited on
Commit
d9a9ea3
1 Parent(s): c0cc5ae

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +40 -6
utils.py CHANGED
@@ -11,6 +11,9 @@ from PIL import Image
11
  from dalle.models import Dalle
12
  from dalle.utils.utils import set_seed, clip_score
13
  import streamlit.components.v1 as components
 
 
 
14
 
15
  def link(link, text, **style):
16
  return a(_href=link, _target="_blank", style=styles(**style))(text)
@@ -21,7 +24,7 @@ def layout(*args):
21
  <style>
22
  # MainMenu {visibility: hidden;}
23
  footer {visibility: hidden;}
24
- .stApp { bottom: 105px; }
25
  </style>
26
  """
27
 
@@ -91,10 +94,42 @@ def footer():
91
 
92
  gtag('config', 'G-SB6NJ9DQS7');
93
  </script>
94
- """,
95
- height=600,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
 
 
 
 
 
 
 
 
98
  model = False
99
  def generate(prompt,crazy,k):
100
  global model
@@ -116,7 +151,7 @@ def generate(prompt,crazy,k):
116
  newPrompt += " architecture"
117
 
118
  images = model.sampling(prompt=newPrompt,
119
- top_k=2048,
120
  top_p=None,
121
  softmax_temperature=crazy,
122
  num_candidates=num_candidates,
@@ -137,12 +172,11 @@ def generate(prompt,crazy,k):
137
  item = {}
138
  item['prompt'] = prompt
139
  item['crazy'] = crazy
140
- item['k'] = k
141
  item['image'] = Image.fromarray((result*255).astype(np.uint8))
142
  st.session_state.results.append(item)
143
 
144
 
145
-
146
  def drawGrid():
147
  master = {}
148
 
 
11
  from dalle.models import Dalle
12
  from dalle.utils.utils import set_seed, clip_score
13
  import streamlit.components.v1 as components
14
+ import torch
15
+ from IPython.display import display
16
+ import random
17
 
18
  def link(link, text, **style):
19
  return a(_href=link, _target="_blank", style=styles(**style))(text)
 
24
  <style>
25
  # MainMenu {visibility: hidden;}
26
  footer {visibility: hidden;}
27
+ .stApp { bottom: 125px; }
28
  </style>
29
  """
30
 
 
94
 
95
  gtag('config', 'G-SB6NJ9DQS7');
96
  </script>
97
+ """
98
+ )
99
+
100
+
101
+ from min_dalle import MinDalle
102
+
103
+ def generate2(prompt,crazy,k):
104
+
105
+
106
+ mm = MinDalle(
107
+ models_root='./pretrained',
108
+ dtype=torch.float32,
109
+ device='cpu',
110
+ is_mega=False,
111
+ is_reusable=True
112
+ )
113
+
114
+
115
+ image = mm.generate_image(
116
+ text=prompt,
117
+ seed=np.random.randint(0,10000),
118
+ grid_size=1,
119
+ is_seamless=False,
120
+ temperature=crazy,
121
+ top_k=k,#2128,
122
+ supercondition_factor=32,
123
+ is_verbose=False
124
  )
125
 
126
+ item = {}
127
+ item['prompt'] = prompt
128
+ item['crazy'] = crazy
129
+ item['k'] = k
130
+ item['image'] = image
131
+ st.session_state.results.append(item)
132
+
133
  model = False
134
  def generate(prompt,crazy,k):
135
  global model
 
151
  newPrompt += " architecture"
152
 
153
  images = model.sampling(prompt=newPrompt,
154
+ top_k=256,
155
  top_p=None,
156
  softmax_temperature=crazy,
157
  num_candidates=num_candidates,
 
172
  item = {}
173
  item['prompt'] = prompt
174
  item['crazy'] = crazy
175
+ item['k'] = 20
176
  item['image'] = Image.fromarray((result*255).astype(np.uint8))
177
  st.session_state.results.append(item)
178
 
179
 
 
180
  def drawGrid():
181
  master = {}
182