jayparmr committed
Commit 2c6c92a · 1 Parent(s): 7fbdac4

Upload folder using huggingface_hub

inference.py CHANGED
@@ -19,6 +19,7 @@ from internals.pipelines.pose_detector import PoseDetector
 from internals.pipelines.prompt_modifier import PromptModifier
 from internals.pipelines.replace_background import ReplaceBackground
 from internals.pipelines.safety_checker import SafetyChecker
+from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler
 from internals.util.args import apply_style_args
 from internals.util.avatar import Avatar
 from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
@@ -55,6 +56,8 @@ img2img_pipe = Img2Img()
 safety_checker = SafetyChecker()
 slack = Slack()
 avatar = Avatar()
+sdxl_tileupscaler = SDXLTileUpscaler()
+
 
 custom_scripts: List = []
 
@@ -145,28 +148,42 @@ def tile_upscale(task: Task):
 
     prompt = get_patched_prompt_tile_upscale(task)
 
-    controlnet.load_model("tile_upscaler")
+    if get_is_sdxl():
+        lora_patcher = lora_style.get_patcher(sdxl_tileupscaler.pipe, task.get_style())
+        lora_patcher.patch()
+
+        images, has_nsfw = sdxl_tileupscaler.process(
+            prompt=prompt,
+            imageUrl=task.get_imageUrl(),
+            resize_dimension=task.get_resize_dimension(),
+            negative_prompt=task.get_negative_prompt(),
+            width=task.get_width(),
+            height=task.get_height(),
+        )
 
-    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
-    lora_patcher.patch()
+        lora_patcher.cleanup()
+    else:
+        controlnet.load_model("tile_upscaler")
 
-    kwargs = {
-        "imageUrl": task.get_imageUrl(),
-        "seed": task.get_seed(),
-        "num_inference_steps": task.get_steps(),
-        "negative_prompt": task.get_negative_prompt(),
-        "width": task.get_width(),
-        "height": task.get_height(),
-        "prompt": prompt,
-        "resize_dimension": task.get_resize_dimension(),
-        **task.cnt_kwargs(),
-    }
-    images, has_nsfw = controlnet.process(**kwargs)
+        lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
+        lora_patcher.patch()
 
-    generated_image_url = upload_image(images[0], output_key)
+        kwargs = {
+            "imageUrl": task.get_imageUrl(),
+            "seed": task.get_seed(),
+            "num_inference_steps": task.get_steps(),
+            "negative_prompt": task.get_negative_prompt(),
+            "width": task.get_width(),
+            "height": task.get_height(),
+            "prompt": prompt,
+            "resize_dimension": task.get_resize_dimension(),
+            **task.cnt_kwargs(),
+        }
+        images, has_nsfw = controlnet.process(**kwargs)
+        lora_patcher.cleanup()
+        controlnet.cleanup()
 
-    lora_patcher.cleanup()
-    controlnet.cleanup()
+    generated_image_url = upload_image(images[0], output_key)
 
     return {
         "modified_prompts": prompt,
@@ -582,7 +599,10 @@ def load_model_by_task(task: Task):
         replace_background.load(base=text2img_pipe, high_res=high_res)
     else:
        if task.get_type() == TaskType.TILE_UPSCALE:
-            controlnet.load_model("tile_upscaler")
+            if get_is_sdxl():
+                sdxl_tileupscaler.create(text2img_pipe)
+            else:
+                controlnet.load_model("tile_upscaler")
        elif task.get_type() == TaskType.CANNY:
            controlnet.load_model("canny")
        elif task.get_type() == TaskType.SCRIBBLE:
internals/pipelines/commons.py CHANGED
@@ -3,15 +3,16 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from diffusers import (
+    AutoencoderKL,
     StableDiffusionImg2ImgPipeline,
-    StableDiffusionXLPipeline,
     StableDiffusionXLImg2ImgPipeline,
+    StableDiffusionXLPipeline,
 )
 
 from internals.data.result import Result
 from internals.pipelines.twoStepPipeline import two_step_pipeline
 from internals.util.commons import disable_safety_checker, download_image
-from internals.util.config import get_hf_token, num_return_sequences, get_is_sdxl
+from internals.util.config import get_hf_token, get_is_sdxl, num_return_sequences
 
 
 class AbstractPipeline:
@@ -32,12 +33,18 @@ class Text2Img(AbstractPipeline):
 
     def load(self, model_dir: str):
         if get_is_sdxl():
-            self.pipe = StableDiffusionXLPipeline.from_pretrained(
+            vae = AutoencoderKL.from_pretrained(
+                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+            )
+            pipe = StableDiffusionXLPipeline.from_pretrained(
                 model_dir,
                 torch_dtype=torch.float16,
                 use_auth_token=get_hf_token(),
                 use_safetensors=True,
-            ).to("cuda")
+            )
+            pipe.vae = vae
+            pipe.to("cuda")
+            self.pipe = pipe
         else:
             self.pipe = two_step_pipeline.from_pretrained(
                 model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
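
The hunk above swaps in the madebyollin/sdxl-vae-fp16-fix autoencoder, which is commonly used to avoid numerical issues (black or NaN outputs) when running the stock SDXL VAE in float16. A minimal standalone sketch of the same pattern with diffusers, where the base model id is only illustrative and not this repo's checkpoint:

import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

# Load the fp16-safe SDXL VAE and hand it to the pipeline at construction time.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative model id
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
).to("cuda")

Passing vae= at load time ends in the same state as assigning pipe.vae after from_pretrained, which is what the diff does.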
internals/pipelines/demofusion_sdxl.py ADDED
The diff for this file is too large to render. See raw diff
 
internals/pipelines/sdxl_tile_upscale.py ADDED
@@ -0,0 +1,87 @@
+import torch
+from diffusers import ControlNetModel
+from PIL import Image
+from torchvision import transforms
+
+import internals.util.image as ImageUtils
+from internals.data.result import Result
+from internals.pipelines.commons import AbstractPipeline, Text2Img
+from internals.pipelines.controlnets import ControlNet
+from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline
+from internals.util.commons import download_image
+from internals.util.config import get_base_dimension
+
+controlnet = ControlNet()
+
+
+class SDXLTileUpscaler(AbstractPipeline):
+    def create(self, pipeline: Text2Img):
+        controlnet = ControlNetModel.from_pretrained(
+            "thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=torch.float16
+        )
+        pipe = DemoFusionSDXLControlNetPipeline(
+            **pipeline.pipe.components, controlnet=controlnet
+        )
+        pipe = pipe.to("cuda")
+        pipe.enable_vae_tiling()
+        pipe.enable_vae_slicing()
+        pipe.enable_xformers_memory_efficient_attention()
+
+        self.pipe = pipe
+
+    def process(
+        self,
+        prompt: str,
+        imageUrl: str,
+        resize_dimension: int,
+        negative_prompt: str,
+        width: int,
+        height: int,
+    ):
+        pose_image = controlnet.detect_pose(imageUrl)
+        img = download_image(imageUrl).resize((width, height))
+
+        img = ImageUtils.resize_image(img, get_base_dimension())
+        pose_image = pose_image.resize(img.size)
+
+        img2 = self.__resize_for_condition_image(img, resize_dimension)
+
+        image_lr = self.load_and_process_image(img)
+        print("img", img2.size, img.size)
+        images = self.pipe.__call__(
+            image_lr=image_lr,
+            prompt=prompt,
+            condition_image=pose_image,
+            negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
+            guidance_scale=11,
+            sigma=0.8,
+            num_inference_steps=24,
+            width=img2.size[0],
+            height=img2.size[1],
+        )
+        images = images[::-1]
+        return images, False
+
+    def load_and_process_image(self, pil_image):
+        transform = transforms.Compose(
+            [
+                transforms.Resize((1024, 1024)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+            ]
+        )
+        image = transform(pil_image)
+        image = image.unsqueeze(0).half()
+        image = image.to("cuda")
+        return image
+
+    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
+        input_image = image.convert("RGB")
+        W, H = input_image.size
+        k = float(resolution) / max(W, H)
+        H *= k
+        W *= k
+        H = int(round(H / 64.0)) * 64
+        W = int(round(W / 64.0)) * 64
+        img = input_image.resize((W, H), resample=Image.LANCZOS)
+        return img
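
For reference, a minimal usage sketch of the new class, mirroring how inference.py wires it in this commit. The model directory, prompt, and image URL below are placeholders, and the sketch assumes the service config reports an SDXL model so that Text2Img.load() takes the SDXL branch:

from internals.pipelines.commons import Text2Img
from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler

# Load the base SDXL pipeline, then build the tile upscaler from its components.
text2img = Text2Img()
text2img.load(model_dir="/path/to/sdxl-model")  # placeholder model directory

upscaler = SDXLTileUpscaler()
upscaler.create(text2img)

images, has_nsfw = upscaler.process(
    prompt="a detailed portrait, sharp focus",  # placeholder prompt
    imageUrl="https://example.com/input.png",   # placeholder input image URL
    resize_dimension=2048,
    negative_prompt="blurry, low quality",
    width=1024,
    height=1024,
)
# images[0] is the final upscaled frame, the same element inference.py passes to upload_image().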