jayparmr's picture
update : inference
35575bb verified
import math
from typing import Dict, List, Optional
from PIL import Image
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline, Img2Img
from internals.util import get_generators
from internals.util.cache import clear_cuda_and_gc
from internals.util.config import (
get_base_dimension,
get_is_sdxl,
get_model_dir,
get_num_return_sequences,
)
from internals.util.sdxl_lightning import LightningMixin
class HighRes(AbstractPipeline, LightningMixin):
    """Resize already-generated images and refine them with an img2img pass.

    The underlying pipeline is borrowed from an ``Img2Img`` instance. When the
    configured model is SDXL (``get_is_sdxl()``), the SDXL-Lightning settings
    provided by ``LightningMixin`` are configured on load and toggled on/off
    around every ``apply`` call.
    """

    def load(self, img2img: Optional[Img2Img] = None):
        """Attach an img2img pipeline to this instance; no-op if already loaded.

        Args:
            img2img: An existing ``Img2Img`` to reuse. When omitted, a new
                ``Img2Img`` is created and loaded from ``get_model_dir()``.
        """
        if hasattr(self, "pipe"):
            return

        if img2img is None:
            img2img = Img2Img()
            img2img.load(get_model_dir())

        self.pipe = img2img.pipe
        self.img2img = img2img

        if get_is_sdxl():
            self.configure_sdxl_lightning(img2img.pipe)

    def apply(
        self,
        prompt: List[str],
        negative_prompt: List[str],
        images,
        width: int,
        height: int,
        seed: int,
        num_inference_steps: int,
        strength: float = 0.5,
        guidance_scale: float = 9,
        **kwargs,
    ):
        """Resize ``images`` to ``width`` x ``height`` and run them through img2img.

        Args:
            prompt: Positive prompts forwarded to the pipeline.
            negative_prompt: Negative prompts forwarded to the pipeline.
            images: Iterable of PIL images; each is resized to
                ``(width, height)`` before being passed as ``image``.
            width: Target width in pixels.
            height: Target height in pixels.
            seed: Seed passed to ``get_generators`` together with
                ``get_num_return_sequences()``.
            num_inference_steps: Step count forwarded to the pipeline.
            strength: img2img strength forwarded to the pipeline.
            guidance_scale: Guidance scale forwarded to the pipeline.
            **kwargs: Extra pipeline arguments. They override the named
                arguments above; on SDXL the Lightning arguments returned by
                ``enable_sdxl_lightning()`` override both.

        Returns:
            ``Result.from_result(...)`` of the pipeline output.
        """
        clear_cuda_and_gc()

        generator = get_generators(seed, get_num_return_sequences())
        images = [image.resize((width, height)) for image in images]

        # Read the flag once so enabling and disabling Lightning always pair up.
        is_sdxl = get_is_sdxl()
        if is_sdxl:
            # Lightning arguments are merged over whatever the caller passed.
            kwargs.update(self.enable_sdxl_lightning())

        pipe_kwargs = {
            "prompt": prompt,
            "image": images,
            "strength": strength,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            # Spread last so caller/Lightning overrides win over the defaults.
            **kwargs,
        }
        print(pipe_kwargs)

        try:
            result = self.pipe(**pipe_kwargs)
        finally:
            # Restore the non-Lightning configuration even if the pipeline
            # call raises, so a failed request does not leave it enabled.
            if is_sdxl:
                self.disable_sdxl_lightning()

        return Result.from_result(result)

    @staticmethod
    def get_intermediate_dimension(target_width: int, target_height: int):
        """Compute a first-pass size with the target's aspect ratio.

        The target size is scaled so its pixel count is roughly
        ``get_base_dimension() ** 2``. Each side is rounded up to the next
        multiple of 64. On SDXL the result is then snapped to the closest
        entry of ``SD_XL_BASE_RATIOS``.

        Args:
            target_width: Final output width in pixels (must be non-zero).
            target_height: Final output height in pixels (must be non-zero).

        Returns:
            ``(width, height)`` for the first generation pass.
        """
        base_size = get_base_dimension()
        desired_pixel_count = base_size * base_size
        actual_pixel_count = target_width * target_height
        scale = math.sqrt(desired_pixel_count / actual_pixel_count)

        firstpass_width = math.ceil(scale * target_width / 64) * 64
        firstpass_height = math.ceil(scale * target_height / 64) * 64
        print("Pass1", firstpass_width, firstpass_height)

        if get_is_sdxl():
            firstpass_width, firstpass_height = HighRes.find_closest_sdxl_aspect_ratio(
                firstpass_width, firstpass_height
            )
            print("Pass2", firstpass_width, firstpass_height)

        return firstpass_width, firstpass_height

    @staticmethod
    def find_closest_sdxl_aspect_ratio(target_width: int, target_height: int):
        """Return the ``SD_XL_BASE_RATIOS`` size whose aspect ratio is closest.

        The ratio of each entry is computed from its ``(width, height)`` pair.
        On ties, the earliest entry in the table wins.

        Args:
            target_width: Width of the size to match.
            target_height: Height of the size to match (must be non-zero).

        Returns:
            ``(width, height)`` of the closest bucket.
        """
        target_ratio = target_width / target_height
        # min() keeps the first minimal element, matching the old strict-< scan.
        return min(
            SD_XL_BASE_RATIOS.values(),
            key=lambda size: abs(target_ratio - size[0] / size[1]),
        )
# SDXL resolution buckets: ratio label -> (width, height) in pixels.
# Labels are width/height rounded to at most two decimals and serve only as
# identifiers; HighRes.find_closest_sdxl_aspect_ratio recomputes the ratio
# from the pair itself. Every dimension here is a multiple of 64.
# NOTE(review): appears to mirror the SD_XL_BASE_RATIOS table in Stability
# AI's generative-models repo — confirm before editing entries.
SD_XL_BASE_RATIOS = {
    label: (width, height)
    for label, width, height in (
        ("0.5", 704, 1408),
        ("0.52", 704, 1344),
        ("0.57", 768, 1344),
        ("0.6", 768, 1280),
        ("0.68", 832, 1216),
        ("0.72", 832, 1152),
        ("0.78", 896, 1152),
        ("0.82", 896, 1088),
        ("0.88", 960, 1088),
        ("0.94", 960, 1024),
        ("1.0", 1024, 1024),
        ("1.07", 1024, 960),
        ("1.13", 1088, 960),
        ("1.21", 1088, 896),
        ("1.29", 1152, 896),
        ("1.38", 1152, 832),
        ("1.46", 1216, 832),
        ("1.67", 1280, 768),
        ("1.75", 1344, 768),
        ("1.91", 1344, 704),
        ("2.0", 1408, 704),
        ("2.09", 1472, 704),
        ("2.4", 1536, 640),
        ("2.5", 1600, 640),
        ("2.89", 1664, 576),
        ("3.0", 1728, 576),
    )
}