jayparmr commited on
Commit
9d63ece
·
1 Parent(s): cd51d32

Upload folder using huggingface_hub

Browse files
inference.py CHANGED
@@ -14,13 +14,11 @@ from internals.pipelines.prompt_modifier import PromptModifier
14
  from internals.pipelines.safety_checker import SafetyChecker
15
  from internals.util.args import apply_style_args
16
  from internals.util.avatar import Avatar
17
- from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
 
18
  from internals.util.commons import pickPoses, upload_image, upload_images
19
- from internals.util.config import (
20
- num_return_sequences,
21
- set_configs_from_task,
22
- set_root_dir,
23
- )
24
  from internals.util.failure_hander import FailureHandler
25
  from internals.util.lora_style import LoraStyle
26
  from internals.util.slack import Slack
@@ -295,17 +293,17 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
295
  lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
296
  lora_patcher.patch()
297
 
298
- if poses is None:
299
- if task.get_pose_coordinates():
300
- infered_pose = pose_detector.transform(
301
- image=task.get_imageUrl(),
302
- client_coordinates=task.get_pose_coordinates(),
303
- width=task.get_width(),
304
- height=task.get_height(),
305
- )
306
- poses = [infered_pose] * num_return_sequences
307
- else:
308
- poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences
309
 
310
  images, has_nsfw = controlnet.process_pose(
311
  prompt=prompt,
 
14
  from internals.pipelines.safety_checker import SafetyChecker
15
  from internals.util.args import apply_style_args
16
  from internals.util.avatar import Avatar
17
+ from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
18
+ clear_cuda_and_gc)
19
  from internals.util.commons import pickPoses, upload_image, upload_images
20
+ from internals.util.config import (num_return_sequences, set_configs_from_task,
21
+ set_root_dir)
 
 
 
22
  from internals.util.failure_hander import FailureHandler
23
  from internals.util.lora_style import LoraStyle
24
  from internals.util.slack import Slack
 
293
  lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
294
  lora_patcher.patch()
295
 
296
+ try:
297
+ infered_pose = pose_detector.transform(
298
+ image=task.get_imageUrl(),
299
+ client_coordinates=task.get_pose_coordinates(),
300
+ width=task.get_width(),
301
+ height=task.get_height(),
302
+ )
303
+ poses = [infered_pose] * num_return_sequences
304
+ except Exception as e:
305
+ print("Failed to detect pose, using Open Pose detector", e)
306
+ poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences
307
 
308
  images, has_nsfw = controlnet.process_pose(
309
  prompt=prompt,
internals/pipelines/pose_detector.py CHANGED
@@ -4,7 +4,7 @@ from typing import Optional, Union
4
  from PIL import Image, ImageDraw
5
  from torch import ge
6
 
7
- from internals.util.commons import download_file, download_image
8
  from internals.util.config import get_root_dir
9
  from models.pose.body import Body
10
 
@@ -77,16 +77,18 @@ class PoseDetector:
77
  image = Image.new("RGB", (width, height), "black")
78
  draw = ImageDraw.Draw(image)
79
 
80
- points = data["candidate"]
81
  for pair in self.__pose_logical_map:
82
- xy = points[pair[0] - 1]
83
- x1y1 = points[pair[1] - 1]
 
 
 
 
 
 
 
84
 
85
- draw.line(
86
- (xy[0], xy[1], x1y1[0], x1y1[1]),
87
- fill=pair[2],
88
- width=4,
89
- )
90
  for i, point in enumerate(points):
91
  x = point[0]
92
  y = point[1]
@@ -99,7 +101,7 @@ class PoseDetector:
99
  subset = []
100
 
101
  if type(image) == str:
102
- image = download_image(imageUrl)
103
 
104
  image = image.resize((width, height))
105
 
 
4
  from PIL import Image, ImageDraw
5
  from torch import ge
6
 
7
+ from internals.util.commons import download_file, download_image, safe_index
8
  from internals.util.config import get_root_dir
9
  from models.pose.body import Body
10
 
 
77
  image = Image.new("RGB", (width, height), "black")
78
  draw = ImageDraw.Draw(image)
79
 
80
+ points: list = data["candidate"]
81
  for pair in self.__pose_logical_map:
82
+ xy = safe_index(points, pair[0] - 1)
83
+ x1y1 = safe_index(points, pair[1] - 1)
84
+
85
+ if xy and x1y1:
86
+ draw.line(
87
+ (xy[0], xy[1], x1y1[0], x1y1[1]),
88
+ fill=pair[2],
89
+ width=4,
90
+ )
91
 
 
 
 
 
 
92
  for i, point in enumerate(points):
93
  x = point[0]
94
  y = point[1]
 
101
  subset = []
102
 
103
  if type(image) == str:
104
+ image = download_image(image)
105
 
106
  image = image.resize((width, height))
107
 
internals/util/commons.py CHANGED
@@ -5,7 +5,7 @@ import random
5
  import re
6
  from io import BytesIO
7
  from pathlib import Path
8
- from typing import Union
9
 
10
  import boto3
11
  import requests
@@ -191,6 +191,14 @@ def construct_default_s3_url(key):
191
  return "https://comic-assets.s3.ap-south-1.amazonaws.com/" + key
192
 
193
 
 
 
 
 
 
 
 
 
194
  def read_url(url: str):
195
  with urllib.request.urlopen(url) as u:
196
  return u.read()
 
5
  import re
6
  from io import BytesIO
7
  from pathlib import Path
8
+ from typing import Optional, Union
9
 
10
  import boto3
11
  import requests
 
191
  return "https://comic-assets.s3.ap-south-1.amazonaws.com/" + key
192
 
193
 
194
+ def safe_index(array, index) -> Optional:
195
+ if index < 0:
196
+ return None
197
+ if index >= len(array):
198
+ return None
199
+ return array[index]
200
+
201
+
202
  def read_url(url: str):
203
  with urllib.request.urlopen(url) as u:
204
  return u.read()