Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
from diffusers import UNet2DModel, DDPMPipeline, DDPMScheduler, DiffusionPipeline | |
import torch | |
import torch.nn.functional as F | |
from matplotlib import pyplot as plt | |
from PIL import Image | |
import spaces | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load the pretrained unconditional diffusion pipeline by repo id, then move
# it onto the selected device.
pipeline = DiffusionPipeline.from_pretrained("gjbooth2/Unconditional_A4C_1")
pipeline = pipeline.to(device)
def image_gen(click, rows=4, cols=4):
    """Generate ``rows * cols`` diffusion samples and tile them into one grid.

    The raw pipeline outputs are pasted as-is (converted to 8-bit grayscale);
    no contrast normalisation is applied.

    Args:
        click: Ignored. Presumably the value passed by a Gradio event; kept so
            the signature stays compatible with existing callers.
        rows: Number of grid rows (must be >= 1).
        cols: Number of grid columns (must be >= 1).

    Returns:
        A mode ``'L'`` ``PIL.Image`` of size ``(cols * w, rows * h)``, where
        ``w, h`` is the size of the first generated sample.

    Raises:
        ValueError: If ``rows`` or ``cols`` is less than 1.
    """
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be >= 1, got rows={rows}, cols={cols}")
    # Generate exactly one sample per grid cell so none are wasted and no
    # cell is left blank.
    images = pipeline(batch_size=rows * cols).images
    w, h = images[0].size
    grid = Image.new('L', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        # Row-major placement: sample i goes to column i % cols, row i // cols.
        grid.paste(image.convert('L'), box=(i % cols * w, i // cols * h))
    return grid
def _normalize_grayscale(pic):
    """Return a contrast-normalised ``uint8`` copy of a 2-D grayscale array.

    - If the darkest pixel is brighter than 55, the sample is treated as washed
      out and replaced by a near-black frame (every pixel = 1) of the same shape.
    - If the image is flat (max == min), min-max stretching is undefined (0/0),
      so every pixel is mapped to 0.
    - Otherwise pixels are linearly stretched from ``[min, max]`` onto
      ``[0, 200]`` and truncated toward zero by the ``uint8`` cast.

    Args:
        pic: 2-D ``uint8`` NumPy array (e.g. from ``np.array(img.convert('L'))``).

    Returns:
        2-D ``uint8`` NumPy array with the same shape as ``pic``.
    """
    lo = int(pic.min())
    hi = int(pic.max())
    if lo > 55:
        # Washed-out sample: blank it to a (near-)black frame.
        return np.ones(pic.shape, dtype=np.uint8)
    if hi == lo:
        # Flat image: avoid the 0/0 that would otherwise produce NaN.
        return np.zeros(pic.shape, dtype=np.uint8)
    # Same float64 formula the per-pixel version used: 200 * ((x - lo) / (hi - lo)).
    return np.uint8(200 * ((pic.astype(np.float64) - lo) / (hi - lo)))


def image_gen_modified(rows=4, cols=4):
    """Generate ``rows * cols`` samples, contrast-normalise each, and tile a grid.

    Each sample is converted to 8-bit grayscale and passed through
    ``_normalize_grayscale`` (washed-out samples are blanked; the rest are
    stretched to a 0..200 range for a more uniform grayscale appearance).

    Args:
        rows: Number of grid rows (must be >= 1).
        cols: Number of grid columns (must be >= 1).

    Returns:
        A mode ``'L'`` ``PIL.Image`` of size ``(cols * w, rows * h)``, where
        ``w, h`` is the size of the first generated sample.

    Raises:
        ValueError: If ``rows`` or ``cols`` is less than 1.
    """
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be >= 1, got rows={rows}, cols={cols}")
    # One sample per grid cell.
    model_output = pipeline(batch_size=rows * cols).images
    pic_hold = [
        Image.fromarray(_normalize_grayscale(np.array(sample.convert('L'))))
        for sample in model_output
    ]
    w, h = model_output[0].size
    grid = Image.new('L', size=(cols * w, rows * h))
    for i, image in enumerate(pic_hold):
        # Row-major placement: sample i goes to column i % cols, row i // cols.
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid
# Gradio UI: a title, a short description, a reference link, and one tab whose
# button generates a normalised image grid via image_gen_modified.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('CS 614 Greg Booth Vision Assignment')
    gr.Markdown('This gradio app can be used to generate realistic cardiac ultrasound images.')
    # `target='_blank'` is the standard attribute for opening the link in a new
    # tab; `rel` prevents the opened page from accessing window.opener.
    # NOTE(review): the model id "Unconditional_A4C_1" suggests apical
    # 4-chamber views, while this link shows the subcostal 4-chamber view —
    # confirm the intended reference.
    gr.HTML(
        "<a href='https://pocus.sg/topic/subcostal-4-chamber/' "
        "target='_blank' rel='noopener noreferrer'>Example anatomy</a>"
    )
    with gr.Tab('Generate a cardiac ultrasound image'):
        playground_btn = gr.Button(value='Push me some images! (may take a couple minutes depending on hardware)')
        playground_out = gr.Image()
        # No `inputs` are wired, so image_gen_modified runs with its default
        # rows/cols (4x4 grid).
        playground_btn.click(image_gen_modified, outputs=playground_out)
demo.launch()