kumahiyo commited on
Commit
7e6f588
1 Parent(s): c98a613

add increase counts of image and reduce gpu usage

Browse files
Files changed (1) hide show
  1. main.py +18 -2
main.py CHANGED
@@ -4,6 +4,7 @@ import sys
4
  import re
5
  import random
6
  import torch
 
7
  from fastapi import FastAPI
8
  from fastapi.staticfiles import StaticFiles
9
  from pydantic import BaseModel
@@ -49,13 +50,16 @@ def draw(data: Data):
49
 
50
  #pipe = StableDiffusionPipeline.from_pretrained(model_id)
51
  pipe = StableDiffusionPipeline.from_pretrained(model_id, revision='fp16', torch_dtype=torch.float16)
 
52
  pipe = pipe.to('cuda')
53
 
54
  generator = torch.Generator("cuda").manual_seed(seed)
55
- image = pipe(prompt, negative_prompt=n_prompt, guidance_scale=7.5, generator=generator).images[0]
 
 
56
 
57
  fileName = "sd_" + str(time.time()) + '.png'
58
- image.save("/code/tmpdir/" + fileName)
59
 
60
  print(fileName)
61
 
@@ -64,3 +68,15 @@ def draw(data: Data):
64
  return {"status": "SORRY! This file is member only.", "file": ""}
65
 
66
  app.mount("/static", StaticFiles(directory="/code/tmpdir"), name="/static")
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import re
5
  import random
6
  import torch
7
+ from PIL import Image
8
  from fastapi import FastAPI
9
  from fastapi.staticfiles import StaticFiles
10
  from pydantic import BaseModel
 
50
 
51
  #pipe = StableDiffusionPipeline.from_pretrained(model_id)
52
  pipe = StableDiffusionPipeline.from_pretrained(model_id, revision='fp16', torch_dtype=torch.float16)
53
+ pipe.enable_attention_slicing() # reduce gpu usage
54
  pipe = pipe.to('cuda')
55
 
56
  generator = torch.Generator("cuda").manual_seed(seed)
57
+ images = pipe(prompt, negative_prompt=n_prompt, guidance_scale=7.5, generator=generator, num_images_per_prompt=3).images
58
+
59
+ grid = image_grid(images, rows=2, cols=2)
60
 
61
  fileName = "sd_" + str(time.time()) + '.png'
62
+ grid.save("/code/tmpdir/" + fileName)
63
 
64
  print(fileName)
65
 
 
68
  return {"status": "SORRY! This file is member only.", "file": ""}
69
 
70
  app.mount("/static", StaticFiles(directory="/code/tmpdir"), name="/static")
71
+
72
+ # helper function taken from: https://huggingface.co/blog/stable_diffusion
73
+ def image_grid(imgs, rows, cols):
74
+ assert len(imgs) == rows*cols
75
+
76
+ w, h = imgs[0].size
77
+ grid = Image.new('RGB', size=(cols*w, rows*h))
78
+ grid_w, grid_h = grid.size
79
+
80
+ for i, img in enumerate(imgs):
81
+ grid.paste(img, box=(i%cols*w, i//cols*h))
82
+ return grid