"""Streamlit helpers for an architecture-focused DALL-E image generator.

Provides footer/layout rendering, two generation back ends (min-dalle via
``generate2`` and minDALL-E via ``generate``) that append results to
``st.session_state.results``, and ``drawGrid`` which renders those results.
"""
import argparse
import functools
import os
import random
import sys

from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
from htbuilder.units import percent, px
from htbuilder.funcs import rgba, rgb
import streamlit as st
import clip
import numpy as np
from PIL import Image
from dalle.models import Dalle
from dalle.utils.utils import set_seed, clip_score
import streamlit.components.v1 as components
import torch
from min_dalle import MinDalle


def link(link, text, **style):
    """Return an ``<a>`` element that opens *link* in a new browser tab.

    Args:
        link: Target URL used as the ``href``.
        text: Anchor text.
        **style: CSS properties forwarded to ``htbuilder.styles``.
    """
    return a(_href=link, _target="_blank", style=styles(**style))(text)


def layout(*args):
    """Render a fixed-position footer bar at the bottom of the page.

    Each argument that is a ``str`` or an ``HtmlElement`` is appended to the
    footer paragraph; arguments of any other type are ignored.
    """
    style = """ """

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )
    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )

    body = p()
    # ``body`` is filled in below; ``foot`` holds a reference to it, so the
    # appended children show up when ``foot`` is stringified.
    foot = div(style=style_div)(hr(style=style_hr), body)

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)


def footer():
    """Render the credits footer as a sequence of Markdown lines."""
    style = """ """
    st.markdown(style, unsafe_allow_html=True)
    # Empty markdown elements, presumably used as vertical spacing.
    for _ in range(3):
        st.markdown("")
    st.markdown("This app uses the [min(DALL·E)](https://github.com/kuprel/min-dalle) port of [DALL·E mini](https://github.com/borisdayma/dalle-mini)")
    st.markdown("Created by [Jonathan Malott](https://jonathanmalott.com)")
    st.markdown("[Good Systems Grand Challenge](https://bridgingbarriers.utexas.edu/good-systems), The University of Texas at Austin. Advised by Dr. Junfeng Jiao.")


def _augment_prompt(prompt):
    """Append " architecture" to *prompt* unless it already contains that word (case-insensitive)."""
    if "architecture" in prompt.lower():
        return prompt
    return prompt + " architecture"


def _record_result(prompt, crazy, k, image):
    """Append one generation result dict to ``st.session_state.results``."""
    st.session_state.results.append(
        {"prompt": prompt, "crazy": crazy, "k": k, "image": image}
    )


@functools.lru_cache(maxsize=1)
def _load_min_dalle():
    """Construct the min-dalle model once and reuse it for every request.

    The original code rebuilt (and re-read the weights for) a fresh
    ``MinDalle`` on each call even though ``is_reusable=True`` marks the
    instance as meant to be reused.
    """
    return MinDalle(
        models_root='./pretrained',
        dtype=torch.float32,
        device='cpu',
        is_mega=False,
        is_reusable=True,
    )


def generate2(prompt, crazy, k):
    """Generate one image with min-dalle and append it to the session results.

    Args:
        prompt: User prompt; " architecture" is appended for sampling if absent.
        crazy: Sampling temperature passed to ``generate_image``.
        k: Top-k cutoff passed to ``generate_image``.
    """
    image = _load_min_dalle().generate_image(
        text=_augment_prompt(prompt),
        seed=np.random.randint(0, 10000),
        grid_size=1,
        is_seamless=False,
        temperature=crazy,
        top_k=k,
        supercondition_factor=32,
        is_verbose=False,
    )
    # Store the prompt as typed (not augmented); drawGrid groups on it.
    _record_result(prompt, crazy, k, image)


# Lazily loaded minDALL-E model used by ``generate``; ``False`` until first use.
model = False


@functools.lru_cache(maxsize=2)
def _load_clip(device):
    """Load the CLIP ViT-B/32 re-ranker once per device and reuse it."""
    model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
    model_clip.to(device=device)
    return model_clip, preprocess_clip


def generate(prompt, crazy, k, num_candidates=1):
    """Generate an image with minDALL-E and append it to the session results.

    Args:
        prompt: User prompt; " architecture" is appended for sampling if absent.
        crazy: Softmax temperature used for sampling.
        k: Top-k cutoff used for sampling.
        num_candidates: Number of images to sample. When greater than one,
            CLIP re-ranks them and only the best-scoring image is kept.
    """
    global model
    device = 'cpu'
    if model is False:
        # This will automatically download the pretrained model.
        model = Dalle.from_pretrained('minDALL-E/1.3B')
        model.to(device=device)

    set_seed(np.random.randint(0, 10000))

    new_prompt = _augment_prompt(prompt)
    images = model.sampling(
        prompt=new_prompt,
        top_k=k,
        top_p=None,
        softmax_temperature=crazy,
        num_candidates=num_candidates,
        device=device,
    ).cpu().numpy()
    # Move channels last for PIL (sampling output is presumably (N, C, H, W)).
    images = np.transpose(images, (0, 2, 3, 1))

    if num_candidates > 1:
        model_clip, preprocess_clip = _load_clip(device)
        rank = clip_score(
            prompt=new_prompt,
            images=images,
            model_clip=model_clip,
            preprocess_clip=preprocess_clip,
            device=device,
        )
        # rank is assumed to be candidate indices ordered best-first (as in
        # minDALL-E's sampling example); atleast_1d guards a 0-d result.
        best = int(np.atleast_1d(rank)[0])
    else:
        # A single candidate needs no CLIP re-ranking.
        best = 0

    result = images[best]
    _record_result(
        prompt, crazy, k, Image.fromarray((result * 255).astype(np.uint8))
    )


def drawGrid():
    """Render session results newest-first, grouped by (prompt, temperature, top-k).

    Each group gets a subheader followed by its images spread across three
    columns. ``st.session_state.images`` is replaced with the image elements
    drawn in this run instead of growing on every rerun.
    """
    groups = {}
    for r in reversed(st.session_state.results):
        key = (r['prompt'], str(r['crazy']), str(r['k']))
        groups.setdefault(key, []).append(r)

    drawn = []
    placeholder = st.empty()
    with placeholder.container():
        for items in groups.values():
            first = items[0]
            txt = first['prompt'] + " (Temperature:" + str(first['crazy']) + ", Top K:" + str(first['k']) + ")"
            st.subheader(txt)
            cols = st.columns(3)
            for ix, item in enumerate(items):
                with cols[ix % 3]:
                    drawn.append(st.image(item["image"]))
    st.session_state.images = drawn