jayparmr commited on
Commit
10230ea
·
1 Parent(s): fd5252e

Upload folder using huggingface_hub

Browse files
inference.py CHANGED
@@ -1,10 +1,11 @@
1
  import os
2
  from typing import List, Optional
3
 
 
4
  import torch
5
 
6
  import internals.util.prompt as prompt_util
7
- from internals.data.dataAccessor import update_db
8
  from internals.data.task import Task, TaskType
9
  from internals.pipelines.commons import Img2Img, Text2Img
10
  from internals.pipelines.controlnets import ControlNet
@@ -18,11 +19,15 @@ from internals.pipelines.replace_background import ReplaceBackground
18
  from internals.pipelines.safety_checker import SafetyChecker
19
  from internals.util.args import apply_style_args
20
  from internals.util.avatar import Avatar
21
- from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
22
  from internals.util.commons import download_image, upload_image, upload_images
23
- from internals.util.config import (get_model_dir, num_return_sequences,
24
- set_configs_from_task, set_model_config,
25
- set_root_dir)
 
 
 
 
26
  from internals.util.failure_hander import FailureHandler
27
  from internals.util.lora_style import LoraStyle
28
  from internals.util.model_loader import load_model_from_config
@@ -80,7 +85,7 @@ def canny(task: Task):
80
 
81
  width, height = get_intermediate_dimension(task)
82
 
83
- controlnet.load_canny()
84
 
85
  # pipe2 is used for canny and pose
86
  lora_patcher = lora_style.get_patcher(
@@ -88,7 +93,7 @@ def canny(task: Task):
88
  )
89
  lora_patcher.patch()
90
 
91
- images, has_nsfw = controlnet.process_canny(
92
  prompt=prompt,
93
  imageUrl=task.get_imageUrl(),
94
  seed=task.get_seed(),
@@ -132,12 +137,12 @@ def tile_upscale(task: Task):
132
 
133
  prompt = get_patched_prompt_tile_upscale(task)
134
 
135
- controlnet.load_tile_upscaler()
136
 
137
  lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
138
  lora_patcher.patch()
139
 
140
- images, has_nsfw = controlnet.process_tile_upscaler(
141
  imageUrl=task.get_imageUrl(),
142
  seed=task.get_seed(),
143
  steps=task.get_steps(),
@@ -169,14 +174,14 @@ def scribble(task: Task):
169
 
170
  width, height = get_intermediate_dimension(task)
171
 
172
- controlnet.load_scribble()
173
 
174
  lora_patcher = lora_style.get_patcher(
175
  [controlnet.pipe2, high_res.pipe], task.get_style()
176
  )
177
  lora_patcher.patch()
178
 
179
- images, has_nsfw = controlnet.process_scribble(
180
  imageUrl=task.get_imageUrl(),
181
  seed=task.get_seed(),
182
  steps=task.get_steps(),
@@ -215,14 +220,14 @@ def linearart(task: Task):
215
 
216
  width, height = get_intermediate_dimension(task)
217
 
218
- controlnet.load_linearart()
219
 
220
  lora_patcher = lora_style.get_patcher(
221
  [controlnet.pipe2, high_res.pipe], task.get_style()
222
  )
223
  lora_patcher.patch()
224
 
225
- images, has_nsfw = controlnet.process_linearart(
226
  imageUrl=task.get_imageUrl(),
227
  seed=task.get_seed(),
228
  steps=task.get_steps(),
@@ -261,7 +266,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
261
 
262
  width, height = get_intermediate_dimension(task)
263
 
264
- controlnet.load_pose()
265
 
266
  # pipe2 is used for canny and pose
267
  lora_patcher = lora_style.get_patcher(
@@ -291,7 +296,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
291
  )
292
  condition_image = ControlNet.linearart_condition_image(src_image)
293
 
294
- images, has_nsfw = controlnet.process_pose(
295
  prompt=prompt,
296
  image=poses,
297
  condition_image=[condition_image] * num_return_sequences,
@@ -440,7 +445,7 @@ def inpaint(task: Task):
440
 
441
  generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
442
 
443
- clear_cuda()
444
 
445
  return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
446
 
@@ -469,12 +474,13 @@ def replace_bg(task: Task):
469
  product_scale_width=task.get_image_scale(),
470
  apply_high_res=task.get_high_res_fix(),
471
  conditioning_scale=task.rbg_controlnet_conditioning_scale(),
 
472
  )
473
 
474
  generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
475
 
476
  lora_patcher.cleanup()
477
- clear_cuda()
478
 
479
  return {
480
  "modified_prompts": prompt,
@@ -484,38 +490,33 @@ def replace_bg(task: Task):
484
 
485
 
486
  def load_model_by_task(task: Task):
487
- high_res.load()
488
-
489
- if (
490
- task.get_type()
491
- in [
492
- TaskType.TEXT_TO_IMAGE,
493
- TaskType.IMAGE_TO_IMAGE,
494
- TaskType.INPAINT,
495
- ]
496
- and not text2img_pipe.is_loaded()
497
- ):
498
  text2img_pipe.load(get_model_dir())
499
  img2img_pipe.create(text2img_pipe)
500
- inpainter.load()
501
  high_res.load(img2img_pipe)
502
 
 
 
 
503
  safety_checker.apply(text2img_pipe)
504
  safety_checker.apply(img2img_pipe)
 
 
 
505
  safety_checker.apply(inpainter)
506
  elif task.get_type() == TaskType.REPLACE_BG:
507
  replace_background.load(inpainter=inpainter, high_res=high_res)
508
  else:
509
  if task.get_type() == TaskType.TILE_UPSCALE:
510
- controlnet.load_tile_upscaler()
511
  elif task.get_type() == TaskType.CANNY:
512
- controlnet.load_canny()
513
  elif task.get_type() == TaskType.SCRIBBLE:
514
- controlnet.load_scribble()
515
  elif task.get_type() == TaskType.LINEARART:
516
- controlnet.load_linearart()
517
  elif task.get_type() == TaskType.POSE:
518
- controlnet.load_pose()
519
 
520
  safety_checker.apply(controlnet)
521
 
@@ -589,7 +590,8 @@ def predict_fn(data, pipe):
589
  else:
590
  raise Exception("Invalid task type")
591
  except Exception as e:
592
- print(f"Error: {e}")
593
  slack.error_alert(task, e)
594
  controlnet.cleanup()
 
 
595
  return None
 
1
  import os
2
  from typing import List, Optional
3
 
4
+ import traceback
5
  import torch
6
 
7
  import internals.util.prompt as prompt_util
8
+ from internals.data.dataAccessor import update_db, update_db_source_failed
9
  from internals.data.task import Task, TaskType
10
  from internals.pipelines.commons import Img2Img, Text2Img
11
  from internals.pipelines.controlnets import ControlNet
 
19
  from internals.pipelines.safety_checker import SafetyChecker
20
  from internals.util.args import apply_style_args
21
  from internals.util.avatar import Avatar
22
+ from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
23
  from internals.util.commons import download_image, upload_image, upload_images
24
+ from internals.util.config import (
25
+ get_model_dir,
26
+ num_return_sequences,
27
+ set_configs_from_task,
28
+ set_model_config,
29
+ set_root_dir,
30
+ )
31
  from internals.util.failure_hander import FailureHandler
32
  from internals.util.lora_style import LoraStyle
33
  from internals.util.model_loader import load_model_from_config
 
85
 
86
  width, height = get_intermediate_dimension(task)
87
 
88
+ controlnet.load_model("canny")
89
 
90
  # pipe2 is used for canny and pose
91
  lora_patcher = lora_style.get_patcher(
 
93
  )
94
  lora_patcher.patch()
95
 
96
+ images, has_nsfw = controlnet.process(
97
  prompt=prompt,
98
  imageUrl=task.get_imageUrl(),
99
  seed=task.get_seed(),
 
137
 
138
  prompt = get_patched_prompt_tile_upscale(task)
139
 
140
+ controlnet.load_model("tile_upscaler")
141
 
142
  lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
143
  lora_patcher.patch()
144
 
145
+ images, has_nsfw = controlnet.process(
146
  imageUrl=task.get_imageUrl(),
147
  seed=task.get_seed(),
148
  steps=task.get_steps(),
 
174
 
175
  width, height = get_intermediate_dimension(task)
176
 
177
+ controlnet.load_model("scribble")
178
 
179
  lora_patcher = lora_style.get_patcher(
180
  [controlnet.pipe2, high_res.pipe], task.get_style()
181
  )
182
  lora_patcher.patch()
183
 
184
+ images, has_nsfw = controlnet.process(
185
  imageUrl=task.get_imageUrl(),
186
  seed=task.get_seed(),
187
  steps=task.get_steps(),
 
220
 
221
  width, height = get_intermediate_dimension(task)
222
 
223
+ controlnet.load_model("linearart")
224
 
225
  lora_patcher = lora_style.get_patcher(
226
  [controlnet.pipe2, high_res.pipe], task.get_style()
227
  )
228
  lora_patcher.patch()
229
 
230
+ images, has_nsfw = controlnet.process(
231
  imageUrl=task.get_imageUrl(),
232
  seed=task.get_seed(),
233
  steps=task.get_steps(),
 
266
 
267
  width, height = get_intermediate_dimension(task)
268
 
269
+ controlnet.load_model("pose")
270
 
271
  # pipe2 is used for canny and pose
272
  lora_patcher = lora_style.get_patcher(
 
296
  )
297
  condition_image = ControlNet.linearart_condition_image(src_image)
298
 
299
+ images, has_nsfw = controlnet.process(
300
  prompt=prompt,
301
  image=poses,
302
  condition_image=[condition_image] * num_return_sequences,
 
445
 
446
  generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
447
 
448
+ clear_cuda_and_gc()
449
 
450
  return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
451
 
 
474
  product_scale_width=task.get_image_scale(),
475
  apply_high_res=task.get_high_res_fix(),
476
  conditioning_scale=task.rbg_controlnet_conditioning_scale(),
477
+ model_type=task.get_modelType(),
478
  )
479
 
480
  generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
481
 
482
  lora_patcher.cleanup()
483
+ clear_cuda_and_gc()
484
 
485
  return {
486
  "modified_prompts": prompt,
 
490
 
491
 
492
  def load_model_by_task(task: Task):
493
+ if not text2img_pipe.is_loaded():
 
 
 
 
 
 
 
 
 
 
494
  text2img_pipe.load(get_model_dir())
495
  img2img_pipe.create(text2img_pipe)
 
496
  high_res.load(img2img_pipe)
497
 
498
+ inpainter.init(text2img_pipe)
499
+ controlnet.init(text2img_pipe)
500
+
501
  safety_checker.apply(text2img_pipe)
502
  safety_checker.apply(img2img_pipe)
503
+
504
+ if task.get_type() == TaskType.INPAINT:
505
+ inpainter.load()
506
  safety_checker.apply(inpainter)
507
  elif task.get_type() == TaskType.REPLACE_BG:
508
  replace_background.load(inpainter=inpainter, high_res=high_res)
509
  else:
510
  if task.get_type() == TaskType.TILE_UPSCALE:
511
+ controlnet.load_model("tile_upscaler")
512
  elif task.get_type() == TaskType.CANNY:
513
+ controlnet.load_model("canny")
514
  elif task.get_type() == TaskType.SCRIBBLE:
515
+ controlnet.load_model("scribble")
516
  elif task.get_type() == TaskType.LINEARART:
517
+ controlnet.load_model("linearart")
518
  elif task.get_type() == TaskType.POSE:
519
+ controlnet.load_model("pose")
520
 
521
  safety_checker.apply(controlnet)
522
 
 
590
  else:
591
  raise Exception("Invalid task type")
592
  except Exception as e:
 
593
  slack.error_alert(task, e)
594
  controlnet.cleanup()
595
+ traceback.print_exc()
596
+ update_db_source_failed(task.get_sourceId(), task.get_userId())
597
  return None
inference2.py CHANGED
@@ -13,17 +13,19 @@ from internals.pipelines.img_to_text import Image2Text
13
  from internals.pipelines.inpainter import InPainter
14
  from internals.pipelines.object_remove import ObjectRemoval
15
  from internals.pipelines.prompt_modifier import PromptModifier
16
- from internals.pipelines.remove_background import (RemoveBackground,
17
- RemoveBackgroundV2)
18
  from internals.pipelines.replace_background import ReplaceBackground
19
  from internals.pipelines.safety_checker import SafetyChecker
20
  from internals.pipelines.upscaler import Upscaler
21
  from internals.util.avatar import Avatar
22
  from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
23
- from internals.util.commons import (construct_default_s3_url, upload_image,
24
- upload_images)
25
- from internals.util.config import (num_return_sequences, set_configs_from_task,
26
- set_model_config, set_root_dir)
 
 
 
27
  from internals.util.failure_hander import FailureHandler
28
  from internals.util.lora_style import LoraStyle
29
  from internals.util.model_loader import load_model_from_config
@@ -65,7 +67,7 @@ def tile_upscale(task: Task):
65
 
66
  prompt = get_patched_prompt_tile_upscale(task)
67
 
68
- controlnet.load_tile_upscaler()
69
 
70
  lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
71
  lora_patcher.patch()
@@ -98,7 +100,9 @@ def tile_upscale(task: Task):
98
  @slack.auto_send_alert
99
  def remove_bg(task: Task):
100
  # remove_background = RemoveBackground()
101
- output_image = remove_background_v2.remove(task.get_imageUrl())
 
 
102
 
103
  output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
104
  upload_image(output_image, output_key)
@@ -173,6 +177,7 @@ def replace_bg(task: Task):
173
  extend_object=task.rbg_extend_object(),
174
  product_scale_width=task.get_image_scale(),
175
  conditioning_scale=task.rbg_controlnet_conditioning_scale(),
 
176
  )
177
 
178
  generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
@@ -231,6 +236,7 @@ def model_fn(model_dir):
231
  upscaler.load()
232
  inpainter.load()
233
  high_res.load()
 
234
 
235
  replace_background.load(
236
  upscaler=upscaler, remove_background=remove_background_v2, high_res=high_res
@@ -242,7 +248,7 @@ def model_fn(model_dir):
242
 
243
  def load_model_by_task(task: Task):
244
  if task.get_type() == TaskType.TILE_UPSCALE:
245
- controlnet.load_tile_upscaler()
246
 
247
  safety_checker.apply(controlnet)
248
 
 
13
  from internals.pipelines.inpainter import InPainter
14
  from internals.pipelines.object_remove import ObjectRemoval
15
  from internals.pipelines.prompt_modifier import PromptModifier
16
+ from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2
 
17
  from internals.pipelines.replace_background import ReplaceBackground
18
  from internals.pipelines.safety_checker import SafetyChecker
19
  from internals.pipelines.upscaler import Upscaler
20
  from internals.util.avatar import Avatar
21
  from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
22
+ from internals.util.commons import construct_default_s3_url, upload_image, upload_images
23
+ from internals.util.config import (
24
+ num_return_sequences,
25
+ set_configs_from_task,
26
+ set_model_config,
27
+ set_root_dir,
28
+ )
29
  from internals.util.failure_hander import FailureHandler
30
  from internals.util.lora_style import LoraStyle
31
  from internals.util.model_loader import load_model_from_config
 
67
 
68
  prompt = get_patched_prompt_tile_upscale(task)
69
 
70
+ controlnet.load_model("tile_upscaler")
71
 
72
  lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
73
  lora_patcher.patch()
 
100
  @slack.auto_send_alert
101
  def remove_bg(task: Task):
102
  # remove_background = RemoveBackground()
103
+ output_image = remove_background_v2.remove(
104
+ task.get_imageUrl(), model_type=task.get_modelType()
105
+ )
106
 
107
  output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
108
  upload_image(output_image, output_key)
 
177
  extend_object=task.rbg_extend_object(),
178
  product_scale_width=task.get_image_scale(),
179
  conditioning_scale=task.rbg_controlnet_conditioning_scale(),
180
+ model_type=task.get_modelType(),
181
  )
182
 
183
  generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
 
236
  upscaler.load()
237
  inpainter.load()
238
  high_res.load()
239
+ controlnet.init(high_res)
240
 
241
  replace_background.load(
242
  upscaler=upscaler, remove_background=remove_background_v2, high_res=high_res
 
248
 
249
  def load_model_by_task(task: Task):
250
  if task.get_type() == TaskType.TILE_UPSCALE:
251
+ controlnet.load_model("tile_upscaler")
252
 
253
  safety_checker.apply(controlnet)
254
 
internals/data/dataAccessor.py CHANGED
@@ -1,6 +1,7 @@
1
  import traceback
2
  from typing import Dict, List, Optional
3
 
 
4
  import requests
5
  from pydash import includes
6
 
@@ -9,6 +10,14 @@ from internals.util.config import api_endpoint, api_headers
9
  from internals.util.slack import Slack
10
 
11
 
 
 
 
 
 
 
 
 
12
  def updateSource(sourceId, userId, state):
13
  print("update source is called")
14
  url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
@@ -21,7 +30,8 @@ def updateSource(sourceId, userId, state):
21
  data = {"state": state}
22
 
23
  try:
24
- response = requests.patch(url, headers=headers, json=data, timeout=10)
 
25
  print("update source response", response)
26
  except requests.exceptions.Timeout:
27
  print("Request timed out while updating source")
@@ -47,7 +57,8 @@ def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
47
  data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
48
 
49
  try:
50
- requests.patch(url, headers=headers, json=data)
 
51
  # print("save generation response", response)
52
  except requests.exceptions.Timeout:
53
  print("Request timed out while saving image")
@@ -61,11 +72,12 @@ def getStyles() -> Optional[Dict]:
61
  url = api_endpoint() + "/autodraft-crecoai/style"
62
  print(url)
63
  try:
64
- response = requests.get(
65
- url,
66
- timeout=10,
67
- headers={"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()},
68
- )
 
69
  return response.json()
70
  except requests.exceptions.Timeout:
71
  print("Request timed out while fetching styles")
@@ -78,9 +90,10 @@ def getStyles() -> Optional[Dict]:
78
  def getCharacters(model_id: str) -> Optional[List]:
79
  url = api_endpoint() + "/autodraft-crecoai/model/{}".format(model_id)
80
  try:
81
- response = requests.get(url, timeout=10, headers=api_headers())
82
- response = response.json()
83
- response = response["data"]["characters"]
 
84
  return response
85
  except requests.exceptions.Timeout:
86
  print("Request timed out while fetching characters")
@@ -89,6 +102,10 @@ def getCharacters(model_id: str) -> Optional[List]:
89
  return None
90
 
91
 
 
 
 
 
92
  def update_db(func):
93
  def caller(*args, **kwargs):
94
  if type(args[0]) is not Task:
 
1
  import traceback
2
  from typing import Dict, List, Optional
3
 
4
+ from requests.adapters import Retry, HTTPAdapter
5
  import requests
6
  from pydash import includes
7
 
 
10
  from internals.util.slack import Slack
11
 
12
 
13
+ class RetryRequest:
14
+ def __new__(cls):
15
+ obj = Retry(total=5, backoff_factor=2, status_forcelist=[500, 502, 503, 504])
16
+ session = requests.Session()
17
+ session.mount("https://", HTTPAdapter(max_retries=obj))
18
+ return session
19
+
20
+
21
  def updateSource(sourceId, userId, state):
22
  print("update source is called")
23
  url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
 
30
  data = {"state": state}
31
 
32
  try:
33
+ with RetryRequest() as session:
34
+ response = session.patch(url, headers=headers, json=data, timeout=10)
35
  print("update source response", response)
36
  except requests.exceptions.Timeout:
37
  print("Request timed out while updating source")
 
57
  data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
58
 
59
  try:
60
+ with RetryRequest() as session:
61
+ session.patch(url, headers=headers, json=data)
62
  # print("save generation response", response)
63
  except requests.exceptions.Timeout:
64
  print("Request timed out while saving image")
 
72
  url = api_endpoint() + "/autodraft-crecoai/style"
73
  print(url)
74
  try:
75
+ with RetryRequest() as session:
76
+ response = session.get(
77
+ url,
78
+ timeout=10,
79
+ headers={"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()},
80
+ )
81
  return response.json()
82
  except requests.exceptions.Timeout:
83
  print("Request timed out while fetching styles")
 
90
  def getCharacters(model_id: str) -> Optional[List]:
91
  url = api_endpoint() + "/autodraft-crecoai/model/{}".format(model_id)
92
  try:
93
+ with RetryRequest() as session:
94
+ response = session.get(url, timeout=10, headers=api_headers())
95
+ response = response.json()
96
+ response = response["data"]["characters"]
97
  return response
98
  except requests.exceptions.Timeout:
99
  print("Request timed out while fetching characters")
 
102
  return None
103
 
104
 
105
+ def update_db_source_failed(sourceId, userId):
106
+ updateSource(sourceId, userId, "FAILED")
107
+
108
+
109
  def update_db(func):
110
  def caller(*args, **kwargs):
111
  if type(args[0]) is not Task:
internals/pipelines/commons.py CHANGED
@@ -2,12 +2,16 @@ from dataclasses import dataclass
2
  from typing import Any, Callable, Dict, List, Optional, Union
3
 
4
  import torch
5
- from diffusers import StableDiffusionImg2ImgPipeline
 
 
 
 
6
 
7
  from internals.data.result import Result
8
  from internals.pipelines.twoStepPipeline import two_step_pipeline
9
  from internals.util.commons import disable_safety_checker, download_image
10
- from internals.util.config import get_hf_token, num_return_sequences
11
 
12
 
13
  class AbstractPipeline:
@@ -27,9 +31,17 @@ class Text2Img(AbstractPipeline):
27
  prompt_right: List[str] = None
28
 
29
  def load(self, model_dir: str):
30
- self.pipe = two_step_pipeline.from_pretrained(
31
- model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
32
- ).to("cuda")
 
 
 
 
 
 
 
 
33
  self.__patch()
34
 
35
  def is_loaded(self):
@@ -38,10 +50,16 @@ class Text2Img(AbstractPipeline):
38
  return False
39
 
40
  def create(self, pipeline: AbstractPipeline):
41
- self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
 
 
 
42
  self.__patch()
43
 
44
  def __patch(self):
 
 
 
45
  self.pipe.enable_xformers_memory_efficient_attention()
46
 
47
  @torch.inference_mode()
@@ -92,9 +110,19 @@ class Text2Img(AbstractPipeline):
92
  # two step pipeline
93
  modified_prompt = params.modified_prompt
94
 
95
- result = self.pipe.two_step_pipeline(
96
- prompt=prompt,
97
- modified_prompts=modified_prompt,
 
 
 
 
 
 
 
 
 
 
98
  height=height,
99
  width=width,
100
  num_inference_steps=num_inference_steps,
@@ -111,7 +139,7 @@ class Text2Img(AbstractPipeline):
111
  callback=callback,
112
  callback_steps=callback_steps,
113
  cross_attention_kwargs=cross_attention_kwargs,
114
- iteration=iteration,
115
  )
116
 
117
  return Result.from_result(result)
@@ -124,22 +152,38 @@ class Img2Img(AbstractPipeline):
124
  if self.__loaded:
125
  return
126
 
127
- self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
128
- model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
129
- ).to("cuda")
 
 
 
 
 
 
 
 
130
  self.__patch()
131
 
132
  self.__loaded = True
133
 
134
  def create(self, pipeline: AbstractPipeline):
135
- self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
136
- "cuda"
137
- )
 
 
 
 
 
138
  self.__patch()
139
 
140
  self.__loaded = True
141
 
142
  def __patch(self):
 
 
 
143
  self.pipe.enable_xformers_memory_efficient_attention()
144
 
145
  @torch.inference_mode()
 
2
  from typing import Any, Callable, Dict, List, Optional, Union
3
 
4
  import torch
5
+ from diffusers import (
6
+ StableDiffusionImg2ImgPipeline,
7
+ StableDiffusionXLPipeline,
8
+ StableDiffusionXLImg2ImgPipeline,
9
+ )
10
 
11
  from internals.data.result import Result
12
  from internals.pipelines.twoStepPipeline import two_step_pipeline
13
  from internals.util.commons import disable_safety_checker, download_image
14
+ from internals.util.config import get_hf_token, num_return_sequences, get_is_sdxl
15
 
16
 
17
  class AbstractPipeline:
 
31
  prompt_right: List[str] = None
32
 
33
  def load(self, model_dir: str):
34
+ if get_is_sdxl():
35
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(
36
+ model_dir,
37
+ torch_dtype=torch.float16,
38
+ use_auth_token=get_hf_token(),
39
+ use_safetensors=True,
40
+ ).to("cuda")
41
+ else:
42
+ self.pipe = two_step_pipeline.from_pretrained(
43
+ model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
44
+ ).to("cuda")
45
  self.__patch()
46
 
47
  def is_loaded(self):
 
50
  return False
51
 
52
  def create(self, pipeline: AbstractPipeline):
53
+ if get_is_sdxl():
54
+ self.pipe = StableDiffusionXLPipeline(**pipeline.pipe.components).to("cuda")
55
+ else:
56
+ self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
57
  self.__patch()
58
 
59
  def __patch(self):
60
+ if get_is_sdxl():
61
+ self.pipe.enable_vae_tiling()
62
+ self.pipe.enable_vae_slicing()
63
  self.pipe.enable_xformers_memory_efficient_attention()
64
 
65
  @torch.inference_mode()
 
110
  # two step pipeline
111
  modified_prompt = params.modified_prompt
112
 
113
+ if get_is_sdxl():
114
+ print("Warning: Two step pipeline is not supported on SDXL")
115
+ kwargs = {
116
+ "prompt": modified_prompt,
117
+ }
118
+ else:
119
+ kwargs = {
120
+ "prompt": prompt,
121
+ "modified_prompts": modified_prompt,
122
+ "iteration": iteration,
123
+ }
124
+
125
+ result = self.pipe.__call__(
126
  height=height,
127
  width=width,
128
  num_inference_steps=num_inference_steps,
 
139
  callback=callback,
140
  callback_steps=callback_steps,
141
  cross_attention_kwargs=cross_attention_kwargs,
142
+ **kwargs
143
  )
144
 
145
  return Result.from_result(result)
 
152
  if self.__loaded:
153
  return
154
 
155
+ if get_is_sdxl():
156
+ self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
157
+ model_dir,
158
+ torch_dtype=torch.float16,
159
+ use_auth_token=get_hf_token(),
160
+ use_safetensors=True,
161
+ ).to("cuda")
162
+ else:
163
+ self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
164
+ model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
165
+ ).to("cuda")
166
  self.__patch()
167
 
168
  self.__loaded = True
169
 
170
  def create(self, pipeline: AbstractPipeline):
171
+ if get_is_sdxl():
172
+ self.pipe = StableDiffusionXLImg2ImgPipeline(**pipeline.pipe.components).to(
173
+ "cuda"
174
+ )
175
+ else:
176
+ self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
177
+ "cuda"
178
+ )
179
  self.__patch()
180
 
181
  self.__loaded = True
182
 
183
  def __patch(self):
184
+ if get_is_sdxl():
185
+ self.pipe.enable_vae_tiling()
186
+ self.pipe.enable_vae_slicing()
187
  self.pipe.enable_xformers_memory_efficient_attention()
188
 
189
  @torch.inference_mode()
internals/pipelines/controlnets.py CHANGED
@@ -1,14 +1,20 @@
1
- from typing import List, Union
2
 
3
  import cv2
4
  import numpy as np
 
5
  import torch
6
  from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
7
- from diffusers import (ControlNetModel, DiffusionPipeline,
8
- StableDiffusionControlNetPipeline,
9
- UniPCMultistepScheduler)
10
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import \
11
- MultiControlNetModel
 
 
 
 
 
12
  from PIL import Image
13
  from torch.nn import Linear
14
  from tqdm import gui
@@ -18,156 +24,127 @@ import internals.util.image as ImageUtil
18
  from external.midas import apply_midas
19
  from internals.data.result import Result
20
  from internals.pipelines.commons import AbstractPipeline
21
- from internals.pipelines.tileUpscalePipeline import \
22
- StableDiffusionControlNetImg2ImgPipeline
 
23
  from internals.util.cache import clear_cuda_and_gc
24
  from internals.util.commons import download_image
25
- from internals.util.config import get_hf_cache_dir, get_hf_token, get_model_dir
 
 
 
 
 
 
 
 
26
 
27
 
28
  class ControlNet(AbstractPipeline):
29
  __current_task_name = ""
30
  __loaded = False
31
 
32
- def load(self):
33
- "Should not be called externally"
34
- if self.__loaded:
35
- return
36
 
37
- if not hasattr(self, "controlnet"):
38
- self.load_pose()
39
 
40
- # controlnet pipeline for tile upscaler
41
- pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
42
- get_model_dir(),
43
- controlnet=self.controlnet,
44
- torch_dtype=torch.float16,
45
- use_auth_token=get_hf_token(),
46
- cache_dir=get_hf_cache_dir(),
47
- ).to("cuda")
48
- # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
49
- pipe.enable_model_cpu_offload()
50
- pipe.enable_xformers_memory_efficient_attention()
51
- self.pipe = pipe
52
-
53
- # controlnet pipeline for canny and pose
54
- pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
55
- pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
56
- pipe2.enable_xformers_memory_efficient_attention()
57
- self.pipe2 = pipe2
58
-
59
- self.__loaded = True
60
-
61
- def load_canny(self):
62
- if self.__current_task_name == "canny":
63
  return
64
- canny = ControlNetModel.from_pretrained(
65
- "lllyasviel/control_v11p_sd15_canny",
 
 
 
 
 
 
 
66
  torch_dtype=torch.float16,
67
  cache_dir=get_hf_cache_dir(),
68
  ).to("cuda")
69
- self.__current_task_name = "canny"
70
- self.controlnet = canny
71
 
72
- self.load()
73
 
74
  if hasattr(self, "pipe"):
75
- self.pipe.controlnet = canny
76
  if hasattr(self, "pipe2"):
77
- self.pipe2.controlnet = canny
78
  clear_cuda_and_gc()
79
 
80
- def load_pose(self):
81
- if self.__current_task_name == "pose":
 
82
  return
83
- pose = ControlNetModel.from_pretrained(
84
- "lllyasviel/control_v11p_sd15_openpose",
85
- torch_dtype=torch.float16,
86
- cache_dir=get_hf_cache_dir(),
87
- ).to("cuda")
88
- # lineart = ControlNetModel.from_pretrained(
89
- # "ControlNet-1-1-preview/control_v11p_sd15_lineart",
90
- # torch_dtype=torch.float16,
91
- # cache_dir=get_hf_cache_dir(),
92
- # ).to("cuda")
93
- self.__current_task_name = "pose"
94
- self.controlnet = MultiControlNetModel([pose]).to("cuda")
95
-
96
- self.load()
97
 
98
- if hasattr(self, "pipe"):
99
- self.pipe.controlnet = self.controlnet
100
- if hasattr(self, "pipe2"):
101
- self.pipe2.controlnet = self.controlnet
102
- clear_cuda_and_gc()
103
-
104
- def load_tile_upscaler(self):
105
- if self.__current_task_name == "tile_upscaler":
106
- return
107
- tile_upscaler = ControlNetModel.from_pretrained(
108
- "lllyasviel/control_v11f1e_sd15_tile",
109
- torch_dtype=torch.float16,
110
- cache_dir=get_hf_cache_dir(),
111
- ).to("cuda")
112
- self.__current_task_name = "tile_upscaler"
113
- self.controlnet = tile_upscaler
114
 
115
- self.load()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- if hasattr(self, "pipe"):
118
- self.pipe.controlnet = tile_upscaler
119
- if hasattr(self, "pipe2"):
120
- self.pipe2.controlnet = tile_upscaler
121
- clear_cuda_and_gc()
122
 
123
- def load_scribble(self):
 
 
 
 
124
  if self.__current_task_name == "scribble":
125
- return
126
- scribble = ControlNetModel.from_pretrained(
127
- "lllyasviel/control_v11p_sd15_scribble",
128
- torch_dtype=torch.float16,
129
- cache_dir=get_hf_cache_dir(),
130
- ).to("cuda")
131
- self.__current_task_name = "scribble"
132
- self.controlnet = scribble
133
-
134
- self.load()
135
-
136
- if hasattr(self, "pipe"):
137
- self.pipe.controlnet = scribble
138
- if hasattr(self, "pipe2"):
139
- self.pipe2.controlnet = scribble
140
- clear_cuda_and_gc()
141
-
142
- def load_linearart(self):
143
  if self.__current_task_name == "linearart":
144
- return
145
- linearart = ControlNetModel.from_pretrained(
146
- "ControlNet-1-1-preview/control_v11p_sd15_lineart",
147
- torch_dtype=torch.float16,
148
- cache_dir=get_hf_cache_dir(),
149
- ).to("cuda")
150
- self.__current_task_name = "linearart"
151
- self.controlnet = linearart
152
-
153
- self.load()
154
-
155
- if hasattr(self, "pipe"):
156
- self.pipe.controlnet = linearart
157
- if hasattr(self, "pipe2"):
158
- self.pipe2.controlnet = linearart
159
- clear_cuda_and_gc()
160
-
161
- def cleanup(self):
162
- if hasattr(self, "pipe"):
163
- self.pipe.controlnet = None
164
- if hasattr(self, "pipe2"):
165
- self.pipe2.controlnet = None
166
- self.controlnet = None
167
- del self.controlnet
168
- self.__current_task_name = ""
169
-
170
- clear_cuda_and_gc()
171
 
172
  @torch.inference_mode()
173
  def process_canny(
@@ -228,7 +205,6 @@ class ControlNet(AbstractPipeline):
228
  guidance_scale=guidance_scale,
229
  height=height,
230
  width=width,
231
- controlnet_conditioning_scale=[1.0],
232
  )
233
  return Result.from_result(result)
234
 
@@ -333,6 +309,17 @@ class ControlNet(AbstractPipeline):
333
  )
334
  return Result.from_result(result)
335
 
 
 
 
 
 
 
 
 
 
 
 
336
  def detect_pose(self, imageUrl: str) -> Image.Image:
337
  detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
338
  image = download_image(imageUrl)
@@ -381,3 +368,18 @@ class ControlNet(AbstractPipeline):
381
  W = int(round(W / 64.0)) * 64
382
  img = input_image.resize((W, H), resample=Image.LANCZOS)
383
  return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Literal, Union
2
 
3
  import cv2
4
  import numpy as np
5
+ from pydash import has
6
  import torch
7
  from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
8
+ from diffusers import (
9
+ ControlNetModel,
10
+ DiffusionPipeline,
11
+ StableDiffusionControlNetPipeline,
12
+ UniPCMultistepScheduler,
13
+ StableDiffusionXLControlNetPipeline,
14
+ )
15
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
16
+ MultiControlNetModel,
17
+ )
18
  from PIL import Image
19
  from torch.nn import Linear
20
  from tqdm import gui
 
24
  from external.midas import apply_midas
25
  from internals.data.result import Result
26
  from internals.pipelines.commons import AbstractPipeline
27
+ from internals.pipelines.tileUpscalePipeline import (
28
+ StableDiffusionControlNetImg2ImgPipeline,
29
+ )
30
  from internals.util.cache import clear_cuda_and_gc
31
  from internals.util.commons import download_image
32
+ from internals.util.config import (
33
+ get_hf_cache_dir,
34
+ get_hf_token,
35
+ get_model_dir,
36
+ get_is_sdxl,
37
+ )
38
+
39
+
40
+ CONTROLNET_TYPES = Literal["pose", "canny", "scribble", "linearart", "tile_upscaler"]
41
 
42
 
43
  class ControlNet(AbstractPipeline):
44
  __current_task_name = ""
45
  __loaded = False
46
 
47
+ __pipeline: AbstractPipeline
 
 
 
48
 
49
+ def init(self, pipeline: AbstractPipeline):
50
+ self.__pipeline = pipeline
51
 
52
+ def load_model(self, task_name: CONTROLNET_TYPES):
53
+ config = self.__model_sdxl if get_is_sdxl() else self.__model_normal
54
+ if self.__current_task_name == task_name:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  return
56
+ model = config[task_name]
57
+ if not model:
58
+ raise Exception(f"ControlNet is not supported for {task_name}")
59
+ while model in list(config.keys()):
60
+ task_name = config[model] # pyright: ignore
61
+ model = config[task_name]
62
+
63
+ controlnet = ControlNetModel.from_pretrained(
64
+ model,
65
  torch_dtype=torch.float16,
66
  cache_dir=get_hf_cache_dir(),
67
  ).to("cuda")
68
+ self.__current_task_name = task_name
69
+ self.controlnet = controlnet
70
 
71
+ self.__load()
72
 
73
  if hasattr(self, "pipe"):
74
+ self.pipe.controlnet = controlnet
75
  if hasattr(self, "pipe2"):
76
+ self.pipe2.controlnet = controlnet
77
  clear_cuda_and_gc()
78
 
79
+ def __load(self):
80
+ "Should not be called externally"
81
+ if self.__loaded:
82
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ if not hasattr(self, "controlnet"):
85
+ self.load_model("pose")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ # controlnet pipeline for tile upscaler
88
+ if get_is_sdxl():
89
+ print("Warning: Tile upscale is not supported on SDXL")
90
+
91
+ if self.__pipeline:
92
+ pipe = StableDiffusionXLControlNetPipeline(
93
+ controlnet=self.controlnet, **self.__pipeline.pipe.components
94
+ ).to("cuda")
95
+ else:
96
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
97
+ get_model_dir(),
98
+ controlnet=self.controlnet,
99
+ torch_dtype=torch.float16,
100
+ use_auth_token=get_hf_token(),
101
+ cache_dir=get_hf_cache_dir(),
102
+ use_safetensors=True,
103
+ ).to("cuda")
104
+ pipe.enable_vae_tiling()
105
+ pipe.enable_vae_slicing()
106
+ pipe.enable_xformers_memory_efficient_attention()
107
+ self.pipe2 = pipe
108
+ else:
109
+ if hasattr(self, "__pipeline"):
110
+ pipe = StableDiffusionControlNetImg2ImgPipeline(
111
+ controlnet=self.controlnet, **self.__pipeline.pipe.components
112
+ ).to("cuda")
113
+ else:
114
+ pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
115
+ get_model_dir(),
116
+ controlnet=self.controlnet,
117
+ torch_dtype=torch.float16,
118
+ use_auth_token=get_hf_token(),
119
+ cache_dir=get_hf_cache_dir(),
120
+ ).to("cuda")
121
+ # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
122
+ pipe.enable_model_cpu_offload()
123
+ pipe.enable_xformers_memory_efficient_attention()
124
+ self.pipe = pipe
125
+
126
+ # controlnet pipeline for canny and pose
127
+ pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
128
+ pipe2.scheduler = UniPCMultistepScheduler.from_config(
129
+ pipe2.scheduler.config
130
+ )
131
+ pipe2.enable_xformers_memory_efficient_attention()
132
+ self.pipe2 = pipe2
133
 
134
+ self.__loaded = True
 
 
 
 
135
 
136
+ def process(self, **kwargs):
137
+ if self.__current_task_name == "pose":
138
+ return self.process_pose(**kwargs)
139
+ if self.__current_task_name == "canny":
140
+ return self.process_canny(**kwargs)
141
  if self.__current_task_name == "scribble":
142
+ return self.process_scribble(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  if self.__current_task_name == "linearart":
144
+ return self.process_linearart(**kwargs)
145
+ if self.__current_task_name == "tile_upscaler":
146
+ return self.process_tile_upscaler(**kwargs)
147
+ raise Exception("ControlNet is not loaded with any model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  @torch.inference_mode()
150
  def process_canny(
 
205
  guidance_scale=guidance_scale,
206
  height=height,
207
  width=width,
 
208
  )
209
  return Result.from_result(result)
210
 
 
309
  )
310
  return Result.from_result(result)
311
 
312
+ def cleanup(self):
313
+ if hasattr(self, "pipe") and hasattr(self.pipe, "controlnet"):
314
+ del self.pipe.controlnet
315
+ if hasattr(self, "pipe2") and hasattr(self.pipe2, "controlnet"):
316
+ del self.pipe2.controlnet
317
+ if hasattr(self, "controlnet"):
318
+ del self.controlnet
319
+ self.__current_task_name = ""
320
+
321
+ clear_cuda_and_gc()
322
+
323
  def detect_pose(self, imageUrl: str) -> Image.Image:
324
  detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
325
  image = download_image(imageUrl)
 
368
  W = int(round(W / 64.0)) * 64
369
  img = input_image.resize((W, H), resample=Image.LANCZOS)
370
  return img
371
+
372
+ __model_normal = {
373
+ "pose": "lllyasviel/control_v11p_sd15_openpose",
374
+ "canny": "lllyasviel/control_v11p_sd15_canny",
375
+ "linearart": "lllyasviel/control_v11p_sd15_lineart",
376
+ "scribble": "lllyasviel/control_v11p_sd15_scribble",
377
+ "tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
378
+ }
379
+ __model_sdxl = {
380
+ "pose": "thibaud/controlnet-openpose-sdxl-1.0",
381
+ "canny": "diffusers/controlnet-canny-sdxl-1.0",
382
+ "linearart": "canny",
383
+ "scribble": "canny",
384
+ "tile_upscaler": None,
385
+ }
internals/pipelines/high_res.py CHANGED
@@ -42,7 +42,7 @@ class HighRes(AbstractPipeline):
42
 
43
  @staticmethod
44
  def get_intermediate_dimension(target_width: int, target_height: int):
45
- def_size = 512
46
 
47
  desired_pixel_count = def_size * def_size
48
  actual_pixel_count = target_width * target_height
 
42
 
43
  @staticmethod
44
  def get_intermediate_dimension(target_width: int, target_height: int):
45
+ def_size = 1024
46
 
47
  desired_pixel_count = def_size * def_size
48
  actual_pixel_count = target_width * target_height
internals/pipelines/inpainter.py CHANGED
@@ -1,38 +1,74 @@
1
  from typing import List, Union
2
 
3
  import torch
4
- from diffusers import StableDiffusionInpaintPipeline
5
 
6
  from internals.pipelines.commons import AbstractPipeline
7
  from internals.util.commons import disable_safety_checker, download_image
8
- from internals.util.config import (get_hf_cache_dir, get_hf_token,
9
- get_inpaint_model_path)
 
 
 
 
 
10
 
11
 
12
  class InPainter(AbstractPipeline):
13
  __loaded = False
14
 
 
 
 
15
  def load(self):
16
  if self.__loaded:
17
  return
18
 
19
- self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
20
- get_inpaint_model_path(),
21
- torch_dtype=torch.float16,
22
- cache_dir=get_hf_cache_dir(),
23
- use_auth_token=get_hf_token(),
24
- ).to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  disable_safety_checker(self.pipe)
27
 
 
 
28
  self.__loaded = True
29
 
30
  def create(self, pipeline: AbstractPipeline):
31
- self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
32
- "cuda"
33
- )
 
 
 
 
 
34
  disable_safety_checker(self.pipe)
35
 
 
 
 
 
 
 
 
 
36
  @torch.inference_mode()
37
  def process(
38
  self,
 
1
  from typing import List, Union
2
 
3
  import torch
4
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline
5
 
6
  from internals.pipelines.commons import AbstractPipeline
7
  from internals.util.commons import disable_safety_checker, download_image
8
+ from internals.util.config import (
9
+ get_hf_cache_dir,
10
+ get_hf_token,
11
+ get_is_sdxl,
12
+ get_inpaint_model_path,
13
+ get_model_dir,
14
+ )
15
 
16
 
17
  class InPainter(AbstractPipeline):
18
  __loaded = False
19
 
20
+ def init(self, pipeline: AbstractPipeline):
21
+ self.__base = pipeline
22
+
23
  def load(self):
24
  if self.__loaded:
25
  return
26
 
27
+ if hasattr(self, "__base") and get_inpaint_model_path() == get_model_dir():
28
+ self.create(self.__base)
29
+ self.__loaded = True
30
+ return
31
+
32
+ if get_is_sdxl():
33
+ self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
34
+ get_inpaint_model_path(),
35
+ torch_dtype=torch.float16,
36
+ cache_dir=get_hf_cache_dir(),
37
+ use_auth_token=get_hf_token(),
38
+ ).to("cuda")
39
+ else:
40
+ self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
41
+ get_inpaint_model_path(),
42
+ torch_dtype=torch.float16,
43
+ cache_dir=get_hf_cache_dir(),
44
+ use_auth_token=get_hf_token(),
45
+ ).to("cuda")
46
 
47
  disable_safety_checker(self.pipe)
48
 
49
+ self.__patch()
50
+
51
  self.__loaded = True
52
 
53
  def create(self, pipeline: AbstractPipeline):
54
+ if get_is_sdxl():
55
+ self.pipe = StableDiffusionXLInpaintPipeline(**pipeline.pipe.components).to(
56
+ "cuda"
57
+ )
58
+ else:
59
+ self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
60
+ "cuda"
61
+ )
62
  disable_safety_checker(self.pipe)
63
 
64
+ self.__patch()
65
+
66
+ def __patch(self):
67
+ if get_is_sdxl():
68
+ self.pipe.enable_vae_tiling()
69
+ self.pipe.enable_vae_slicing()
70
+ self.pipe.enable_xformers_memory_efficient_attention()
71
+
72
  @torch.inference_mode()
73
  def process(
74
  self,
internals/pipelines/remove_background.py CHANGED
@@ -1,15 +1,20 @@
1
  import io
2
  from pathlib import Path
3
  from typing import Union
 
 
4
 
5
  import torch
6
  import torch.nn.functional as F
7
  from PIL import Image
8
  from rembg import remove
 
9
 
10
  import internals.util.image as ImageUtil
11
  from carvekit.api.high import HiInterface
12
  from internals.util.commons import download_image, read_url
 
 
13
 
14
 
15
  class RemoveBackground:
@@ -23,6 +28,11 @@ class RemoveBackground:
23
 
24
  class RemoveBackgroundV2:
25
  def __init__(self):
 
 
 
 
 
26
  self.interface = HiInterface(
27
  object_type="object", # Can be "object" or "hairs-like".
28
  batch_size_seg=5,
@@ -36,16 +46,51 @@ class RemoveBackgroundV2:
36
  fp16=False,
37
  )
38
 
39
- def remove(self, image: Union[str, Image.Image]) -> Image.Image:
40
- img_path = Path.home() / ".cache" / "rm_bg.png"
 
41
  if type(image) is str:
42
  image = download_image(image)
43
 
44
- w, h = image.size
45
- if max(w, h) > 1536:
46
- image = ImageUtil.resize_image(image, dimension=1024)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- image.save(img_path)
49
- images_without_background = self.interface([img_path])
50
- out = images_without_background[0]
51
- return out
 
 
 
 
1
  import io
2
  from pathlib import Path
3
  from typing import Union
4
+ import numpy as np
5
+ import cv2
6
 
7
  import torch
8
  import torch.nn.functional as F
9
  from PIL import Image
10
  from rembg import remove
11
+ from internals.data.task import ModelType
12
 
13
  import internals.util.image as ImageUtil
14
  from carvekit.api.high import HiInterface
15
  from internals.util.commons import download_image, read_url
16
+ import onnxruntime as rt
17
+ import huggingface_hub
18
 
19
 
20
  class RemoveBackground:
 
28
 
29
  class RemoveBackgroundV2:
30
  def __init__(self):
31
+ model_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
32
+ self.anime_rembg = rt.InferenceSession(
33
+ model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
34
+ )
35
+
36
  self.interface = HiInterface(
37
  object_type="object", # Can be "object" or "hairs-like".
38
  batch_size_seg=5,
 
46
  fp16=False,
47
  )
48
 
49
+ def remove(
50
+ self, image: Union[str, Image.Image], model_type: ModelType = ModelType.REAL
51
+ ) -> Image.Image:
52
  if type(image) is str:
53
  image = download_image(image)
54
 
55
+ if model_type == ModelType.ANIME or model_type == ModelType.COMIC:
56
+ print("Using Anime Background remover")
57
+ _, img = self.__rmbg_fn(np.array(image))
58
+
59
+ return Image.fromarray(img)
60
+ else:
61
+ print("Using Real Background remover")
62
+ img_path = Path.home() / ".cache" / "rm_bg.png"
63
+
64
+ w, h = image.size
65
+ if max(w, h) > 1536:
66
+ image = ImageUtil.resize_image(image, dimension=1024)
67
+
68
+ image.save(img_path)
69
+ images_without_background = self.interface([img_path])
70
+ out = images_without_background[0]
71
+ return out
72
+
73
+ def __get_mask(self, img, s=1024):
74
+ img = (img / 255).astype(np.float32)
75
+ h, w = h0, w0 = img.shape[:-1]
76
+ h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
77
+ ph, pw = s - h, s - w
78
+ img_input = np.zeros([s, s, 3], dtype=np.float32)
79
+ img_input[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] = cv2.resize(
80
+ img, (w, h)
81
+ )
82
+ img_input = np.transpose(img_input, (2, 0, 1))
83
+ img_input = img_input[np.newaxis, :]
84
+ mask = self.anime_rembg.run(None, {"img": img_input})[0][0]
85
+ mask = np.transpose(mask, (1, 2, 0))
86
+ mask = mask[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
87
+ mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
88
+ return mask
89
 
90
+ def __rmbg_fn(self, img):
91
+ mask = self.__get_mask(img)
92
+ img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
93
+ mask = (mask * 255).astype(np.uint8)
94
+ img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
95
+ mask = mask.repeat(3, axis=2)
96
+ return mask, img
internals/pipelines/replace_background.py CHANGED
@@ -3,10 +3,14 @@ from typing import List, Optional, Union
3
 
4
  import torch
5
  from cv2 import inpaint
6
- from diffusers import (ControlNetModel,
7
- StableDiffusionControlNetInpaintPipeline,
8
- StableDiffusionInpaintPipeline, UniPCMultistepScheduler)
 
 
 
9
  from PIL import Image, ImageFilter, ImageOps
 
10
 
11
  import internals.util.image as ImageUtil
12
  from internals.data.result import Result
@@ -17,8 +21,12 @@ from internals.pipelines.inpainter import InPainter
17
  from internals.pipelines.remove_background import RemoveBackgroundV2
18
  from internals.pipelines.upscaler import Upscaler
19
  from internals.util.commons import download_image
20
- from internals.util.config import (get_hf_cache_dir, get_hf_token,
21
- get_inpaint_model_path, get_model_dir)
 
 
 
 
22
 
23
 
24
  class ReplaceBackground(AbstractPipeline):
@@ -52,7 +60,8 @@ class ReplaceBackground(AbstractPipeline):
52
  cache_dir=get_hf_cache_dir(),
53
  use_auth_token=get_hf_token(),
54
  )
55
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 
56
  pipe.to("cuda")
57
 
58
  self.pipe = pipe
@@ -87,6 +96,7 @@ class ReplaceBackground(AbstractPipeline):
87
  seed: int,
88
  steps: int,
89
  apply_high_res: bool = False,
 
90
  ):
91
  # image = Image.open("original.png")
92
  if type(image) is str:
@@ -98,7 +108,7 @@ class ReplaceBackground(AbstractPipeline):
98
  image = image.convert("RGB")
99
  if max(image.size) > 1024:
100
  image = ImageUtil.resize_image(image, dimension=1024)
101
- image = self.remove_background.remove(image)
102
 
103
  width = int(width)
104
  height = int(height)
 
3
 
4
  import torch
5
  from cv2 import inpaint
6
+ from diffusers import (
7
+ ControlNetModel,
8
+ StableDiffusionControlNetInpaintPipeline,
9
+ StableDiffusionInpaintPipeline,
10
+ UniPCMultistepScheduler,
11
+ )
12
  from PIL import Image, ImageFilter, ImageOps
13
+ from internals.data.task import ModelType
14
 
15
  import internals.util.image as ImageUtil
16
  from internals.data.result import Result
 
21
  from internals.pipelines.remove_background import RemoveBackgroundV2
22
  from internals.pipelines.upscaler import Upscaler
23
  from internals.util.commons import download_image
24
+ from internals.util.config import (
25
+ get_hf_cache_dir,
26
+ get_hf_token,
27
+ get_inpaint_model_path,
28
+ get_model_dir,
29
+ )
30
 
31
 
32
  class ReplaceBackground(AbstractPipeline):
 
60
  cache_dir=get_hf_cache_dir(),
61
  use_auth_token=get_hf_token(),
62
  )
63
+ pipe.enable_xformers_memory_efficient_attention()
64
+ pipe.enable_vae_slicing()
65
  pipe.to("cuda")
66
 
67
  self.pipe = pipe
 
96
  seed: int,
97
  steps: int,
98
  apply_high_res: bool = False,
99
+ model_type: ModelType = ModelType.REAL,
100
  ):
101
  # image = Image.open("original.png")
102
  if type(image) is str:
 
108
  image = image.convert("RGB")
109
  if max(image.size) > 1024:
110
  image = ImageUtil.resize_image(image, dimension=1024)
111
+ image = self.remove_background.remove(image, model_type=model_type)
112
 
113
  width = int(width)
114
  height = int(height)
internals/pipelines/twoStepPipeline.py CHANGED
@@ -12,7 +12,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
12
 
13
  class two_step_pipeline(StableDiffusionPipeline):
14
  @torch.no_grad()
15
- def two_step_pipeline(
16
  self,
17
  prompt: Union[str, List[str]] = None,
18
  modified_prompts: Union[str, List[str]] = None,
 
12
 
13
  class two_step_pipeline(StableDiffusionPipeline):
14
  @torch.no_grad()
15
+ def __call__(
16
  self,
17
  prompt: Union[str, List[str]] = None,
18
  modified_prompts: Union[str, List[str]] = None,
internals/util/cache.py CHANGED
@@ -1,15 +1,25 @@
1
  import gc
2
-
 
3
  import torch
4
 
5
 
 
 
 
 
 
6
  def clear_cuda_and_gc():
7
- clear_cuda()
 
8
  clear_gc()
 
 
9
 
10
 
11
  def clear_cuda():
12
- torch.cuda.empty_cache()
 
13
 
14
 
15
  def clear_gc():
 
1
  import gc
2
+ import os
3
+ import psutil
4
  import torch
5
 
6
 
7
+ def print_memory_usage():
8
+ process = psutil.Process(os.getpid())
9
+ print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:2f} MB")
10
+
11
+
12
  def clear_cuda_and_gc():
13
+ print_memory_usage()
14
+ print("Clearing cuda and gc")
15
  clear_gc()
16
+ clear_cuda()
17
+ print_memory_usage()
18
 
19
 
20
  def clear_cuda():
21
+ with torch.no_grad():
22
+ torch.cuda.empty_cache()
23
 
24
 
25
  def clear_gc():
internals/util/commons.py CHANGED
@@ -150,9 +150,9 @@ def upload_image(image: Union[Image.Image, BytesIO], out_path):
150
  return image_url
151
 
152
 
153
- def download_image(url) -> Image.Image:
154
  response = requests.get(url)
155
- return Image.open(BytesIO(response.content)).convert("RGB")
156
 
157
 
158
  def download_file(url, out_path: Path):
 
150
  return image_url
151
 
152
 
153
+ def download_image(url, mode="RGB") -> Image.Image:
154
  response = requests.get(url)
155
+ return Image.open(BytesIO(response.content)).convert(mode)
156
 
157
 
158
  def download_file(url, out_path: Path):
internals/util/config.py CHANGED
@@ -61,6 +61,11 @@ def get_inpaint_model_path():
61
  return model_config.base_inpaint_model_path # pyright: ignore
62
 
63
 
 
 
 
 
 
64
  def get_root_dir():
65
  global root_dir
66
  return root_dir
 
61
  return model_config.base_inpaint_model_path # pyright: ignore
62
 
63
 
64
+ def get_is_sdxl():
65
+ global model_config
66
+ return model_config.is_sdxl # pyright: ignore
67
+
68
+
69
  def get_root_dir():
70
  global root_dir
71
  return root_dir
internals/util/lora_style.py CHANGED
@@ -10,6 +10,7 @@ from lora_diffusion import patch_pipe, tune_lora_scale
10
  from pydash import chain
11
 
12
  from internals.data.dataAccessor import getStyles
 
13
  from internals.util.commons import download_file
14
 
15
 
@@ -112,6 +113,10 @@ class LoraStyle:
112
  ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
113
  "Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
114
  pipe = [pipe] if not isinstance(pipe, list) else pipe
 
 
 
 
115
  if key in self.__styles:
116
  style = self.__styles[key]
117
  if style["type"] == "diffuser":
 
10
  from pydash import chain
11
 
12
  from internals.data.dataAccessor import getStyles
13
+ from internals.util.config import get_is_sdxl
14
  from internals.util.commons import download_file
15
 
16
 
 
113
  ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
114
  "Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
115
  pipe = [pipe] if not isinstance(pipe, list) else pipe
116
+ if get_is_sdxl():
117
+ print("Warning: Lora is not supported on SDXL")
118
+ return self.EmptyLoraPatcher(pipe)
119
+
120
  if key in self.__styles:
121
  style = self.__styles[key]
122
  if style["type"] == "diffuser":
internals/util/model_loader.py CHANGED
@@ -14,6 +14,7 @@ from tqdm import tqdm
14
  class ModelConfig:
15
  base_model_path: str
16
  base_inpaint_model_path: str
 
17
 
18
 
19
  def load_model_from_config(path):
@@ -23,9 +24,11 @@ def load_model_from_config(path):
23
  config = json.loads(f.read())
24
  model_path = config.get("model_path", path)
25
  inpaint_model_path = config.get("inpaint_model_path", path)
 
26
 
27
  m_config.base_model_path = model_path
28
  m_config.base_inpaint_model_path = inpaint_model_path
 
29
 
30
  #
31
  # if config.get("model_type") == "huggingface":
 
14
  class ModelConfig:
15
  base_model_path: str
16
  base_inpaint_model_path: str
17
+ is_sdxl: bool = False
18
 
19
 
20
  def load_model_from_config(path):
 
24
  config = json.loads(f.read())
25
  model_path = config.get("model_path", path)
26
  inpaint_model_path = config.get("inpaint_model_path", path)
27
+ is_sdxl = config.get("is_sdxl", False)
28
 
29
  m_config.base_model_path = model_path
30
  m_config.base_inpaint_model_path = inpaint_model_path
31
+ m_config.is_sdxl = is_sdxl
32
 
33
  #
34
  # if config.get("model_type") == "huggingface":
pyproject.toml CHANGED
@@ -1,4 +1,4 @@
1
  [tool.pyright]
2
- venvPath = "/Users/devel/Documents/WebProjects/creco-inference"
3
  venv = "env"
4
  exclude = ["env"]
 
1
  [tool.pyright]
2
+ venvPath = "."
3
  venv = "env"
4
  exclude = ["env"]
requirements.txt CHANGED
@@ -15,6 +15,7 @@ realesrgan==0.3.0
15
  compel==1.0.4
16
  scikit-image>=0.19.3
17
  six==1.16.0
 
18
  tifffile==2021.8.30
19
  easydict==1.9.0
20
  albumentations
@@ -32,10 +33,13 @@ xformers==0.0.21
32
  scikit-image==0.19.3
33
  omegaconf==2.3.0
34
  webdataset==0.2.48
 
35
  https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl
36
  python-dateutil==2.8.2
37
  PyYAML
38
  invisible-watermark
39
  torchvision==0.15.2
 
 
40
  imgaug==0.4.0
41
  tqdm==4.64.1
 
15
  compel==1.0.4
16
  scikit-image>=0.19.3
17
  six==1.16.0
18
+ psutil
19
  tifffile==2021.8.30
20
  easydict==1.9.0
21
  albumentations
 
33
  scikit-image==0.19.3
34
  omegaconf==2.3.0
35
  webdataset==0.2.48
36
+ invisible-watermark
37
  https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl
38
  python-dateutil==2.8.2
39
  PyYAML
40
  invisible-watermark
41
  torchvision==0.15.2
42
+ onnx
43
+ onnxruntime-gpu
44
  imgaug==0.4.0
45
  tqdm==4.64.1