# CM2000112 / inference2.py
# Uploaded by jayparmr ("Upload folder using huggingface_hub"), commit 830fe50.
import os
import traceback
from io import BytesIO

import torch

import internals.util.prompt as prompt_util
from internals.data.dataAccessor import update_db, update_db_source_failed
from internals.data.task import ModelType, Task, TaskType
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.inpainter import InPainter
from internals.pipelines.object_remove import ObjectRemoval
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2
from internals.pipelines.replace_background import ReplaceBackground
from internals.pipelines.safety_checker import SafetyChecker
from internals.pipelines.upscaler import Upscaler
from internals.util.avatar import Avatar
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
from internals.util.commons import construct_default_s3_url, upload_image, upload_images
from internals.util.config import (
    num_return_sequences,
    set_configs_from_task,
    set_model_config,
    set_root_dir,
)
from internals.util.failure_hander import FailureHandler
from internals.util.lora_style import LoraStyle
from internals.util.model_loader import load_model_from_config
from internals.util.slack import Slack
# PyTorch backend flags: let cuDNN benchmark and choose convolution
# algorithms, and allow TF32 math for CUDA matmuls (faster, slightly lower
# precision).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# NOTE(review): not referenced anywhere else in this file chunk - confirm it
# is used elsewhere before relying on it.
auto_mode = False
# Module-level singletons shared by every request handled by this module.
# They are constructed at import time. Several of them (prompt_modifier,
# safety_checker, object_removal, upscaler, inpainter, high_res, controlnet,
# replace_background, avatar, lora_style) are loaded/initialised later in
# model_fn.
slack = Slack()
# num_of_sequences is tied to num_return_sequences, the same count the
# handlers use when they replicate prompts without prompt engineering.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()
controlnet = ControlNet()
safety_checker = SafetyChecker()
high_res = HighRes()
object_removal = ObjectRemoval()
remove_background_v2 = RemoveBackgroundV2()
replace_background = ReplaceBackground()
img2text = Image2Text()
img_classifier = ImageClassifier()
avatar = Avatar()
lora_style = LoraStyle()
def get_patched_prompt_tile_upscale(task: Task):
    """Return the tile-upscale prompt for ``task``.

    Thin adapter over ``prompt_util.get_patched_prompt_tile_upscale``. It supplies
    this module's shared avatar, LoRA-style, image-classifier and image-to-text
    objects, so callers only need to pass the task.
    """
    shared_helpers = (avatar, lora_style, img_classifier, img2text)
    return prompt_util.get_patched_prompt_tile_upscale(task, *shared_helpers)
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Upscale ``task``'s image with the ControlNet ``tile_upscaler`` model.

    Loads the tile-upscaler model and applies the task's LoRA style patch to
    ``controlnet.pipe``. It then runs the pipeline and uploads the first
    resulting image.

    Returns:
        dict with ``modified_prompts`` (the patched prompt), ``generated_image_url``
        (the value returned by ``upload_image``) and ``has_nsfw`` (the flag
        returned by ``controlnet.process``).
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())
    prompt = get_patched_prompt_tile_upscale(task)

    controlnet.load_model("tile_upscaler")
    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    lora_patcher.patch()
    try:
        kwargs = {
            "imageUrl": task.get_imageUrl(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "negative_prompt": task.get_negative_prompt(),
            "width": task.get_width(),
            "height": task.get_height(),
            "prompt": prompt,
            "resize_dimension": task.get_resize_dimension(),
            **task.cnt_kwargs(),
        }
        images, has_nsfw = controlnet.process(**kwargs)
        generated_image_url = upload_image(images[0], output_key)
    finally:
        # Undo the LoRA patch even if building kwargs, processing or the upload
        # fails. Otherwise the style could remain applied to the shared
        # ``controlnet.pipe`` for later tasks. On failure, predict_fn's error
        # handler is what calls controlnet.cleanup().
        lora_patcher.cleanup()
    controlnet.cleanup()
    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }
@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background from ``task``'s image and upload the cut-out.

    The image is uploaded under ``crecoAI/<taskId>_rmbg.png``.

    Returns:
        dict with ``generated_image_url`` set to the default S3 URL for that key.
    """
    cutout = remove_background_v2.remove(
        task.get_imageUrl(), model_type=task.get_modelType()
    )
    destination = "crecoAI/{}_rmbg.png".format(task.get_taskId())
    upload_image(cutout, destination)
    public_url = construct_default_s3_url(destination)
    return {"generated_image_url": public_url}
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of ``task``'s image.

    Avatar code names are substituted into the prompt first. The prompt is then
    either rewritten by the prompt modifier (prompt engineering enabled) or
    repeated ``num_return_sequences`` times. Task-specific ``ip_kwargs`` override
    the base pipeline arguments.

    Returns:
        dict with ``modified_prompts`` (the prompts passed to the pipeline) and
        ``generated_image_urls`` (from ``upload_images``).
    """
    base_prompt = avatar.add_code_names(task.get_prompt())
    prompts = (
        prompt_modifier.modify(base_prompt)
        if task.is_prompt_engineering()
        else [base_prompt] * num_return_sequences
    )
    print({"prompts": prompts})

    pipeline_args = {
        "prompt": prompts,
        "image_url": task.get_imageUrl(),
        "mask_image_url": task.get_maskImageUrl(),
        "width": task.get_width(),
        "height": task.get_height(),
        "seed": task.get_seed(),
        "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
        "num_inference_steps": task.get_steps(),
    }
    # Later keys win, exactly as with a trailing ``**`` splat in the literal.
    pipeline_args.update(task.ip_kwargs())

    results = inpainter.process(**pipeline_args)
    urls = upload_images(results, "_inpaint", task.get_taskId())
    clear_cuda()
    return {"modified_prompts": prompts, "generated_image_urls": urls}
@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    """Erase the masked object from ``task``'s image and upload the first result.

    NOTE: although the key is the plural ``generated_image_urls``, its value is
    the single value returned by ``upload_image``. The key is kept unchanged for
    existing consumers.
    """
    destination = "crecoAI/{}_object_remove.png".format(task.get_taskId())
    results = object_removal.process(
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
    )
    first_result = results[0]
    uploaded = upload_image(first_result, destination)
    clear_cuda()
    return {"generated_image_urls": uploaded}
@update_db
@slack.auto_send_alert
def replace_bg(task: Task):
    """Replace the background of ``task``'s image using its prompt.

    The prompt is either rewritten by the prompt modifier (prompt engineering
    enabled) or repeated ``num_return_sequences`` times.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` (from
        ``upload_images``) and ``has_nsfw`` (as returned by the pipeline).
    """
    raw_prompt = task.get_prompt()
    prompts = (
        prompt_modifier.modify(raw_prompt)
        if task.is_prompt_engineering()
        else [raw_prompt] * num_return_sequences
    )
    outputs, nsfw_flag = replace_background.replace(
        image=task.get_imageUrl(),
        prompt=prompts,
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
        steps=task.get_steps(),
        conditioning_scale=task.rbg_controlnet_conditioning_scale(),
        model_type=task.get_modelType(),
    )
    urls = upload_images(outputs, "_replace_bg", task.get_taskId())
    result = {
        "modified_prompts": prompts,
        "generated_image_urls": urls,
        "has_nsfw": nsfw_flag,
    }
    return result
@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale ``task``'s image and upload it as ``crecoAI/<taskId>_upscale.png``.

    ANIME and COMIC model types use ``upscaler.upscale_anime``; every other
    model type uses ``upscaler.upscale``. Both receive the same arguments.

    Returns:
        dict with ``generated_image_url`` set to the default S3 URL for the key.
    """
    destination = "crecoAI/{}_upscale.png".format(task.get_taskId())

    if task.get_modelType() == ModelType.ANIME or task.get_modelType() == ModelType.COMIC:
        print("Using Anime model")
        upscale_fn = upscaler.upscale_anime
    else:
        print("Using Real model")
        upscale_fn = upscaler.upscale

    upscaled_bytes = upscale_fn(
        image=task.get_imageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        face_enhance=task.get_face_enhance(),
        resize_dimension=task.get_resize_dimension(),
    )
    upload_image(BytesIO(upscaled_bytes), destination)
    clear_cuda_and_gc()
    return {"generated_image_url": construct_default_s3_url(destination)}
def model_fn(model_dir):
    """Load configuration and initialise this module's shared pipelines.

    Args:
        model_dir: directory handed to the config loader and to the avatar,
            LoRA-style and object-removal loaders.

    Returns:
        None. The pipelines live in this module's global singletons and are
        loaded in place rather than returned.
    """
    print("Logs: model loaded .... starts")
    # Read the model configuration from model_dir and publish it via config.
    config = load_model_from_config(model_dir)
    set_model_config(config)
    # Anchor the project root on this module's file path.
    set_root_dir(__file__)
    FailureHandler.register()
    # Assets that are read from the model directory.
    avatar.load_local(model_dir)
    lora_style.load(model_dir)
    prompt_modifier.load()
    safety_checker.load()
    object_removal.load(model_dir)
    upscaler.load()
    inpainter.load()
    high_res.load()
    # high_res (and upscaler) are loaded above before being passed into
    # controlnet / replace_background. Presumably they must be ready first,
    # so keep this order.
    controlnet.init(high_res)
    replace_background.load(
        upscaler=upscaler, remove_background=remove_background_v2, high_res=high_res
    )
    print("Logs: model loaded ....")
    return
def load_model_by_task(task: Task):
    """Preload the model a task type needs beyond the defaults.

    Only ``TaskType.TILE_UPSCALE`` needs this: it loads the ControlNet
    ``tile_upscaler`` model and applies the safety checker to ``controlnet``.
    Every other task type is a no-op.
    """
    if task.get_type() != TaskType.TILE_UPSCALE:
        return
    controlnet.load_model("tile_upscaler")
    safety_checker.apply(controlnet)
def _record_task_failure(task: Task, error: Exception):
    """Failure path for ``predict_fn``: alert Slack, clean up ControlNet, mark failed.

    Every step is attempted even when an earlier one raises. A failing Slack
    alert or ControlNet cleanup therefore can no longer stop
    ``update_db_source_failed`` from running. Any exception raised by a step
    still propagates after the remaining steps have run.
    """
    try:
        slack.error_alert(task, error)
    finally:
        try:
            controlnet.cleanup()
        finally:
            update_db_source_failed(task.get_sourceId(), task.get_userId())


@FailureHandler.clear
def predict_fn(data, pipe):
    """Handle one inference request.

    Steps:
    - Build a ``Task`` from ``data`` and apply its configuration.
    - Load any task-specific model.
    - Apply the safety checker to the shared pipelines.
    - Fetch avatars for the task's model id.
    - Dispatch to the handler for the task type.

    Args:
        data: raw request payload used to construct the ``Task``.
        pipe: unused here. Presumably the (None) value returned by ``model_fn``,
            supplied by the serving framework - confirm.

    Returns:
        Whatever the selected handler returns. Returns None for ``SYSTEM_CMD``
        tasks, and None after a failure has been recorded.
    """
    task = Task(data)
    print("task is ", data)
    FailureHandler.handle(task)
    try:
        # Set set_environment
        set_configs_from_task(task)
        # Load model based on task
        load_model_by_task(task)
        # Apply safety checker based on environment
        safety_checker.apply(inpainter)
        safety_checker.apply(replace_background)
        safety_checker.apply(high_res)
        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()
        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        elif task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        elif task_type == TaskType.REPLACE_BG:
            return replace_bg(task)
        elif task_type == TaskType.TILE_UPSCALE:
            return tile_upscale(task)
        elif task_type == TaskType.SYSTEM_CMD:
            # SECURITY(review): this runs the request's prompt text verbatim in
            # a shell. If request data can come from untrusted users, anyone who
            # can submit a task can execute arbitrary commands on this host.
            # Gate it behind authentication / an allow-list, or remove it.
            os.system(task.get_prompt())
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        # Keep the stack trace; the one-line message alone hides where it failed.
        traceback.print_exc()
        _record_task_failure(task, e)
        return None