Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,8 @@ import gradio as gr
|
|
8 |
import torch
|
9 |
from diffusers import StableDiffusionPipeline
|
10 |
from torch import autocast
|
|
|
|
|
11 |
#from PIL import Image
|
12 |
#from torchvision import transforms
|
13 |
|
@@ -18,6 +20,7 @@ openai.api_key = os.getenv('openaikey')
|
|
18 |
authtoken = os.getenv('authtoken')
|
19 |
|
20 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
21 |
pipe = StableDiffusionPipeline.from_pretrained("stale2000/sd-dnditem", torch_dtype=torch.float16, use_auth_token=authtoken)
|
22 |
pipe = pipe.to(device)
|
23 |
|
@@ -55,9 +58,13 @@ def predict(input, manual_query_repacement, history=[]):
|
|
55 |
# return images, False
|
56 |
# pipe.safety_checker = null_safety
|
57 |
|
58 |
-
|
59 |
-
|
|
|
60 |
|
|
|
|
|
|
|
61 |
for idx, im in enumerate(images):
|
62 |
im.save(f"{idx:06}.png")
|
63 |
|
|
|
8 |
import torch
|
9 |
from diffusers import StableDiffusionPipeline
|
10 |
from torch import autocast
|
11 |
+
|
12 |
+
from contextlib import nullcontext
|
13 |
#from PIL import Image
|
14 |
#from torchvision import transforms
|
15 |
|
|
|
20 |
authtoken = os.getenv('authtoken')
|
21 |
|
22 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
23 |
+
context = autocast if device == "cuda" else nullcontext
|
24 |
pipe = StableDiffusionPipeline.from_pretrained("stale2000/sd-dnditem", torch_dtype=torch.float16, use_auth_token=authtoken)
|
25 |
pipe = pipe.to(device)
|
26 |
|
|
|
58 |
# return images, False
|
59 |
# pipe.safety_checker = null_safety
|
60 |
|
61 |
+
|
62 |
+
#with autocast("cuda"):
|
63 |
+
# images = pipe(n_samples*[prompt], guidance_scale=scale).images
|
64 |
|
65 |
+
with context("cuda"):
|
66 |
+
images = pipe(n_samples*[prompt], guidance_scale=scale, num_inference_steps=5).images
|
67 |
+
|
68 |
for idx, im in enumerate(images):
|
69 |
im.save(f"{idx:06}.png")
|
70 |
|