Spaces:
Running
Running
File size: 2,169 Bytes
49f65e4 269eef7 49f65e4 a8c5eee 49f65e4 a8c5eee 49f65e4 a8c5eee 49f65e4 c8067b8 d0b54f1 7f5a3aa 70913b7 d0b54f1 e198db2 d0b54f1 b0f7a83 d0b54f1 e198db2 d0b54f1 b0f7a83 d0b54f1 b0f7a83 e198db2 d0b54f1 b54a771 d0b54f1 b0f7a83 d0b54f1 b0f7a83 d0b54f1 7f5a3aa 49f65e4 269eef7 49f65e4 269eef7 43d4050 269eef7 43d4050 9698c7f f6f1fa1 ae52322 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import torch
from PIL import Image
from RealESRGAN import RealESRGAN
import gradio as gr
# Run on the GPU when torch can see one; otherwise use the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def _load_pretrained(scale, weights_path):
    """Construct a RealESRGAN upscaler for *scale* and load *weights_path* into it.

    ``download=True`` is forwarded to ``load_weights``; presumably this lets the
    RealESRGAN package fetch the file when it is missing locally -- confirm
    against that package's documentation.
    """
    upscaler = RealESRGAN(device, scale=scale)
    upscaler.load_weights(weights_path, download=True)
    return upscaler


# One preloaded model per upscale factor offered in the UI (2x, 4x, 8x).
model2 = _load_pretrained(2, 'weights/RealESRGAN_x2.pth')
model4 = _load_pretrained(4, 'weights/RealESRGAN_x4.pth')
model8 = _load_pretrained(8, 'weights/RealESRGAN_x8.pth')
def _reload_model(scale):
    """Rebuild the RealESRGAN model for *scale* from its local weight file.

    Used as the recovery path after a CUDA out-of-memory error. The start-up
    code already loaded these weights (with ``download=True``), so
    ``download=False`` is passed here.
    """
    model = RealESRGAN(device, scale=scale)
    model.load_weights(f'weights/RealESRGAN_x{scale}.pth', download=False)
    return model


def inference(image, size):
    """Upscale *image* with the RealESRGAN model chosen by *size*.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when the
            user has not uploaded anything.
        size: ``'2x'`` or ``'4x'`` select those models; any other value (the UI
            offers ``'8x'``) selects the 8x model.

    Returns:
        Whatever ``RealESRGAN.predict`` returns for the RGB-converted image.

    Raises:
        gr.Error: if no image was provided, or if the 8x model is requested for
            an image whose width or height is 5000 px or more.
    """
    # The OOM recovery paths rebind the module-level models.
    global model2, model4, model8
    if image is None:
        raise gr.Error("Image not uploaded")
    if torch.cuda.is_available():
        # Release cached-but-unused GPU memory before a potentially large run.
        torch.cuda.empty_cache()
    if size == '2x':
        try:
            result = model2.predict(image.convert('RGB'))
        except torch.cuda.OutOfMemoryError as e:
            print(e)
            # Retry once with a freshly built model of the requested scale.
            model2 = _reload_model(2)
            result = model2.predict(image.convert('RGB'))
    elif size == '4x':
        try:
            result = model4.predict(image.convert('RGB'))
        except torch.cuda.OutOfMemoryError as e:
            print(e)
            # Retry with the 4x model so a 4x request never yields a 2x image.
            model4 = _reload_model(4)
            result = model4.predict(image.convert('RGB'))
    else:
        # 8x output grows the image 64x in area; refuse very large inputs.
        width, height = image.size
        if width >= 5000 or height >= 5000:
            raise gr.Error("The image is too large.")
        try:
            result = model8.predict(image.convert('RGB'))
        except torch.cuda.OutOfMemoryError as e:
            print(e)
            # Retry with the 8x model so an 8x request never yields a 2x image.
            model8 = _reload_model(8)
            result = model8.predict(image.convert('RGB'))
    print(f"Image size ({device}): {size} ... OK")
    return result
# Input widgets: the source picture and the upscale-factor picker.
_inputs = [
    gr.Image(type="pil"),
    gr.Radio(["2x", "4x", "8x"], type="value", value="2x", label="Resolution model"),
]
# Output widget showing the upscaled result.
_output = gr.Image(type="pil", label="Output")

# Wire the components to `inference` and start serving the app.
demo = gr.Interface(fn=inference, inputs=_inputs, outputs=_output)
demo.launch()
|