|
import os |
|
import torch |
|
from gfpgan import GFPGANer |
|
from PIL import Image |
|
import cv2 |
|
|
|
class EndpointHandler:
    """Wrap a ``GFPGANer`` instance to restore faces in image files.

    The GFPGAN model is constructed once in ``__init__``; each call to
    ``enhance_image`` reads an image from disk with OpenCV, runs
    ``GFPGANer.enhance`` on it and writes the result to disk.
    """

    def __init__(self, model_path='GFPGANv1.4.pth', upscale=2, output_dir='output'):
        """Build the GFPGAN restorer and make sure the output directory exists.

        Args:
            model_path: Weights file handed to ``GFPGANer`` as ``model_path``.
            upscale: Upsampling factor handed to ``GFPGANer`` (default 2, as before).
            output_dir: Directory used for the default save path of
                ``enhance_image``; created here if missing.
        """
        self.model_path = model_path
        self.output_dir = output_dir
        # No background upsampler is configured, so only the GFPGAN face
        # restoration is applied.
        self.bg_upsampler = None

        # NOTE(review): arch='clean' with channel_multiplier=2 appears to be the
        # configuration GFPGAN pairs with v1.3/v1.4 weights — confirm if
        # model_path points at a different checkpoint.
        self.face_enhancer = GFPGANer(
            model_path=self.model_path,
            upscale=upscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=self.bg_upsampler,
        )

        os.makedirs(self.output_dir, exist_ok=True)

    def enhance_image(self, image_path, save_path=None):
        """Restore faces in the image at ``image_path`` and save the result.

        Args:
            image_path: Path of the input image, read with
                ``cv2.IMREAD_UNCHANGED``.
            save_path: Destination file for the enhanced image. Defaults to
                ``<output_dir>/enhanced_image.png``, which is overwritten on
                every call that does not pass an explicit path.

        Returns:
            Tuple ``(output, save_path)``: the pasted-back image returned as
            the third element of ``GFPGANer.enhance`` and the path it was
            written to.

        Raises:
            ValueError: if OpenCV cannot read ``image_path``.
            RuntimeError: if OpenCV fails to write the result to ``save_path``.
        """
        # cv2.imread signals failure by returning None instead of raising;
        # catch it here rather than letting GFPGAN fail on a None input.
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise ValueError(f"Could not read image: {image_path!r}")

        # Detect faces in the full image, restore all of them (not only the
        # centre one) and paste them back into the original frame.
        _, _, output = self.face_enhancer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True)

        if save_path is None:
            save_path = os.path.join(self.output_dir, "enhanced_image.png")
        parent = os.path.dirname(save_path)
        if parent:
            # The directory may have been removed since __init__, or a custom
            # save_path may point into a directory that does not exist yet.
            os.makedirs(parent, exist_ok=True)

        # cv2.imwrite reports failure via its boolean return value, not an exception.
        if not cv2.imwrite(save_path, output):
            raise RuntimeError(f"Failed to write enhanced image to {save_path!r}")

        return output, save_path
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-test entry point: enhance the image given as the first
    # command-line argument, falling back to the placeholder path below.
    import sys

    handler = EndpointHandler()

    test_image_path = sys.argv[1] if len(sys.argv) > 1 else 'path_to_test_image.jpg'

    _, save_path = handler.enhance_image(test_image_path)

    print(f"Enhanced image saved at {save_path}")
|
|