Upload folder using huggingface_hub
Browse files- inference.py +15 -17
- internals/pipelines/pose_detector.py +12 -10
- internals/util/commons.py +9 -1
inference.py
CHANGED
@@ -14,13 +14,11 @@ from internals.pipelines.prompt_modifier import PromptModifier
|
|
14 |
from internals.pipelines.safety_checker import SafetyChecker
|
15 |
from internals.util.args import apply_style_args
|
16 |
from internals.util.avatar import Avatar
|
17 |
-
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda,
|
|
|
18 |
from internals.util.commons import pickPoses, upload_image, upload_images
|
19 |
-
from internals.util.config import (
|
20 |
-
|
21 |
-
set_configs_from_task,
|
22 |
-
set_root_dir,
|
23 |
-
)
|
24 |
from internals.util.failure_hander import FailureHandler
|
25 |
from internals.util.lora_style import LoraStyle
|
26 |
from internals.util.slack import Slack
|
@@ -295,17 +293,17 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
295 |
lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
|
296 |
lora_patcher.patch()
|
297 |
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
|
310 |
images, has_nsfw = controlnet.process_pose(
|
311 |
prompt=prompt,
|
|
|
14 |
from internals.pipelines.safety_checker import SafetyChecker
|
15 |
from internals.util.args import apply_style_args
|
16 |
from internals.util.avatar import Avatar
|
17 |
+
from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
|
18 |
+
clear_cuda_and_gc)
|
19 |
from internals.util.commons import pickPoses, upload_image, upload_images
|
20 |
+
from internals.util.config import (num_return_sequences, set_configs_from_task,
|
21 |
+
set_root_dir)
|
|
|
|
|
|
|
22 |
from internals.util.failure_hander import FailureHandler
|
23 |
from internals.util.lora_style import LoraStyle
|
24 |
from internals.util.slack import Slack
|
|
|
293 |
lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
|
294 |
lora_patcher.patch()
|
295 |
|
296 |
+
try:
|
297 |
+
infered_pose = pose_detector.transform(
|
298 |
+
image=task.get_imageUrl(),
|
299 |
+
client_coordinates=task.get_pose_coordinates(),
|
300 |
+
width=task.get_width(),
|
301 |
+
height=task.get_height(),
|
302 |
+
)
|
303 |
+
poses = [infered_pose] * num_return_sequences
|
304 |
+
except Exception as e:
|
305 |
+
print("Failed to detect pose, using Open Pose detector", e)
|
306 |
+
poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences
|
307 |
|
308 |
images, has_nsfw = controlnet.process_pose(
|
309 |
prompt=prompt,
|
internals/pipelines/pose_detector.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Optional, Union
|
|
4 |
from PIL import Image, ImageDraw
|
5 |
from torch import ge
|
6 |
|
7 |
-
from internals.util.commons import download_file, download_image
|
8 |
from internals.util.config import get_root_dir
|
9 |
from models.pose.body import Body
|
10 |
|
@@ -77,16 +77,18 @@ class PoseDetector:
|
|
77 |
image = Image.new("RGB", (width, height), "black")
|
78 |
draw = ImageDraw.Draw(image)
|
79 |
|
80 |
-
points = data["candidate"]
|
81 |
for pair in self.__pose_logical_map:
|
82 |
-
xy = points
|
83 |
-
x1y1 = points
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
-
draw.line(
|
86 |
-
(xy[0], xy[1], x1y1[0], x1y1[1]),
|
87 |
-
fill=pair[2],
|
88 |
-
width=4,
|
89 |
-
)
|
90 |
for i, point in enumerate(points):
|
91 |
x = point[0]
|
92 |
y = point[1]
|
@@ -99,7 +101,7 @@ class PoseDetector:
|
|
99 |
subset = []
|
100 |
|
101 |
if type(image) == str:
|
102 |
-
image = download_image(
|
103 |
|
104 |
image = image.resize((width, height))
|
105 |
|
|
|
4 |
from PIL import Image, ImageDraw
|
5 |
from torch import ge
|
6 |
|
7 |
+
from internals.util.commons import download_file, download_image, safe_index
|
8 |
from internals.util.config import get_root_dir
|
9 |
from models.pose.body import Body
|
10 |
|
|
|
77 |
image = Image.new("RGB", (width, height), "black")
|
78 |
draw = ImageDraw.Draw(image)
|
79 |
|
80 |
+
points: list = data["candidate"]
|
81 |
for pair in self.__pose_logical_map:
|
82 |
+
xy = safe_index(points, pair[0] - 1)
|
83 |
+
x1y1 = safe_index(points, pair[1] - 1)
|
84 |
+
|
85 |
+
if xy and x1y1:
|
86 |
+
draw.line(
|
87 |
+
(xy[0], xy[1], x1y1[0], x1y1[1]),
|
88 |
+
fill=pair[2],
|
89 |
+
width=4,
|
90 |
+
)
|
91 |
|
|
|
|
|
|
|
|
|
|
|
92 |
for i, point in enumerate(points):
|
93 |
x = point[0]
|
94 |
y = point[1]
|
|
|
101 |
subset = []
|
102 |
|
103 |
if type(image) == str:
|
104 |
+
image = download_image(image)
|
105 |
|
106 |
image = image.resize((width, height))
|
107 |
|
internals/util/commons.py
CHANGED
@@ -5,7 +5,7 @@ import random
|
|
5 |
import re
|
6 |
from io import BytesIO
|
7 |
from pathlib import Path
|
8 |
-
from typing import Union
|
9 |
|
10 |
import boto3
|
11 |
import requests
|
@@ -191,6 +191,14 @@ def construct_default_s3_url(key):
|
|
191 |
return "https://comic-assets.s3.ap-south-1.amazonaws.com/" + key
|
192 |
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
def read_url(url: str):
|
195 |
with urllib.request.urlopen(url) as u:
|
196 |
return u.read()
|
|
|
5 |
import re
|
6 |
from io import BytesIO
|
7 |
from pathlib import Path
|
8 |
+
from typing import Optional, Union
|
9 |
|
10 |
import boto3
|
11 |
import requests
|
|
|
191 |
return "https://comic-assets.s3.ap-south-1.amazonaws.com/" + key
|
192 |
|
193 |
|
194 |
+
def safe_index(array, index) -> Optional:
|
195 |
+
if index < 0:
|
196 |
+
return None
|
197 |
+
if index >= len(array):
|
198 |
+
return None
|
199 |
+
return array[index]
|
200 |
+
|
201 |
+
|
202 |
def read_url(url: str):
|
203 |
with urllib.request.urlopen(url) as u:
|
204 |
return u.read()
|