jayparmr commited on
Commit
0daeeb0
·
1 Parent(s): 19b3da3

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. .gitignore +6 -0
  2. inference.py +29 -1
  3. internals/pipelines/inpainter.py +6 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ *.pyc
2
+ .ipynb_checkpoints █
3
+ env
4
+ test.py
5
+ *.jpeg
6
+ __pycache__
inference.py CHANGED
@@ -8,11 +8,12 @@ from internals.pipelines.commons import Img2Img, Text2Img
8
  from internals.pipelines.controlnets import ControlNet
9
  from internals.pipelines.img_classifier import ImageClassifier
10
  from internals.pipelines.img_to_text import Image2Text
 
11
  from internals.pipelines.prompt_modifier import PromptModifier
12
  from internals.pipelines.safety_checker import SafetyChecker
13
  from internals.util.args import apply_style_args
14
  from internals.util.avatar import Avatar
15
- from internals.util.cache import auto_clear_cuda_and_gc
16
  from internals.util.commons import pickPoses, upload_image, upload_images
17
  from internals.util.config import set_configs_from_task, set_root_dir
18
  from internals.util.failure_hander import FailureHandler
@@ -26,6 +27,7 @@ num_return_sequences = 4 # the number of results to generate
26
  auto_mode = False
27
 
28
  prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
 
29
  img2text = Image2Text()
30
  img_classifier = ImageClassifier()
31
  controlnet = ControlNet()
@@ -269,6 +271,29 @@ def img2img(task: Task):
269
  }
270
 
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  def model_fn(model_dir):
273
  print("Logs: model loaded .... starts")
274
 
@@ -288,6 +313,7 @@ def model_fn(model_dir):
288
  controlnet.load(model_dir)
289
  text2img_pipe.load(model_dir)
290
  img2img_pipe.create(text2img_pipe)
 
291
 
292
  safety_checker.apply(text2img_pipe)
293
  safety_checker.apply(img2img_pipe)
@@ -333,6 +359,8 @@ def predict_fn(data, pipe):
333
  return pose(task)
334
  elif task_type == TaskType.TILE_UPSCALE:
335
  return tile_upscale(task)
 
 
336
  else:
337
  raise Exception("Invalid task type")
338
  except Exception as e:
 
8
  from internals.pipelines.controlnets import ControlNet
9
  from internals.pipelines.img_classifier import ImageClassifier
10
  from internals.pipelines.img_to_text import Image2Text
11
+ from internals.pipelines.inpainter import InPainter
12
  from internals.pipelines.prompt_modifier import PromptModifier
13
  from internals.pipelines.safety_checker import SafetyChecker
14
  from internals.util.args import apply_style_args
15
  from internals.util.avatar import Avatar
16
+ from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
17
  from internals.util.commons import pickPoses, upload_image, upload_images
18
  from internals.util.config import set_configs_from_task, set_root_dir
19
  from internals.util.failure_hander import FailureHandler
 
27
  auto_mode = False
28
 
29
  prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
30
+ inpainter = InPainter()
31
  img2text = Image2Text()
32
  img_classifier = ImageClassifier()
33
  controlnet = ControlNet()
 
271
  }
272
 
273
 
274
+ @update_db
275
+ @slack.auto_send_alert
276
+ def inpaint(task: Task):
277
+ prompt, _ = get_patched_prompt(task)
278
+
279
+ print({"prompts": prompt})
280
+
281
+ images = inpainter.process(
282
+ prompt=prompt,
283
+ image_url=task.get_imageUrl(),
284
+ mask_image_url=task.get_maskImageUrl(),
285
+ width=task.get_width(),
286
+ height=task.get_height(),
287
+ seed=task.get_seed(),
288
+ negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
289
+ )
290
+ generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
291
+
292
+ clear_cuda()
293
+
294
+ return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
295
+
296
+
297
  def model_fn(model_dir):
298
  print("Logs: model loaded .... starts")
299
 
 
313
  controlnet.load(model_dir)
314
  text2img_pipe.load(model_dir)
315
  img2img_pipe.create(text2img_pipe)
316
+ inpainter.create(text2img_pipe)
317
 
318
  safety_checker.apply(text2img_pipe)
319
  safety_checker.apply(img2img_pipe)
 
359
  return pose(task)
360
  elif task_type == TaskType.TILE_UPSCALE:
361
  return tile_upscale(task)
362
+ elif task_type == TaskType.INPAINT:
363
+ return inpaint(task)
364
  else:
365
  raise Exception("Invalid task type")
366
  except Exception as e:
internals/pipelines/inpainter.py CHANGED
@@ -15,6 +15,12 @@ class InPainter(AbstractPipeline):
15
  ).to("cuda")
16
  disable_safety_checker(self.pipe)
17
 
 
 
 
 
 
 
18
  @torch.inference_mode()
19
  def process(
20
  self,
 
15
  ).to("cuda")
16
  disable_safety_checker(self.pipe)
17
 
18
+ def create(self, pipeline: AbstractPipeline):
19
+ self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
20
+ "cuda"
21
+ )
22
+ disable_safety_checker(self.pipe)
23
+
24
  @torch.inference_mode()
25
  def process(
26
  self,