ZhengPeng7 chrismaltais committed on
Commit
6284dc0
·
verified ·
1 Parent(s): e7c2780

simplify-inference-logic (#4)

Browse files

- Simplify inference logic (2bbd3bf560c2d714643ed7eb5fb988409b83cd0a)


Co-authored-by: Chris Maltais <[email protected]>

Files changed (1) hide show
  1. app.py +37 -31
app.py CHANGED
@@ -1,14 +1,17 @@
1
  import os
2
- from glob import glob
3
  import cv2
4
  import numpy as np
5
- from PIL import Image
6
  import torch
7
- from torchvision import transforms
8
- from transformers import AutoModelForImageSegmentation
9
  import gradio as gr
10
  import spaces
 
 
 
 
 
11
  from gradio_imageslider import ImageSlider
 
 
12
 
13
  torch.set_float32_matmul_precision('high')
14
  torch.jit.script = lambda f: f
@@ -16,21 +19,21 @@ torch.jit.script = lambda f: f
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
 
19
- def array_to_pil_image(image, size=(1024, 1024)):
20
  image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
21
  image = Image.fromarray(image).convert('RGB')
22
  return image
23
 
24
 
25
  class ImagePreprocessor():
26
- def __init__(self, resolution=(1024, 1024)) -> None:
27
  self.transform_image = transforms.Compose([
28
  # transforms.Resize(resolution), # 1. keep consistent with the cv2.resize used in training 2. redundant with that in path_to_image()
29
  transforms.ToTensor(),
30
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
31
  ])
32
 
33
- def proc(self, image):
34
  image = self.transform_image(image)
35
  return image
36
 
@@ -45,14 +48,17 @@ usage_to_weights_file = {
45
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs'
46
  }
47
 
48
- from transformers import AutoModelForImageSegmentation
49
  birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
50
  birefnet.to(device)
51
  birefnet.eval()
52
 
53
 
54
  @spaces.GPU
55
- def predict(image, resolution, weights_file):
 
 
 
 
56
  global birefnet
57
  # Load BiRefNet with chosen weights
58
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
@@ -62,32 +68,32 @@ def predict(image, resolution, weights_file):
62
  birefnet.eval()
63
 
64
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
65
- # Image is a RGB numpy array.
66
  resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
67
- images = [image]
68
- image_shapes = [image.shape[:2] for image in images]
69
- images = [array_to_pil_image(image, resolution) for image in images]
70
 
71
- image_preprocessor = ImagePreprocessor(resolution=resolution)
72
- images_proc = []
73
- for image in images:
74
- images_proc.append(image_preprocessor.proc(image))
75
- images_proc = torch.cat([image_proc.unsqueeze(0) for image_proc in images_proc])
76
 
 
77
  with torch.no_grad():
78
- scaled_preds_tensor = birefnet(images_proc.to(device))[-1].sigmoid() # BiRefNet needs a sigmoid activation outside the forward.
79
- preds = []
80
- for image_shape, pred_tensor in zip(image_shapes, scaled_preds_tensor):
81
- if device == 'cuda':
82
- pred_tensor = pred_tensor.cpu()
83
- preds.append(torch.nn.functional.interpolate(pred_tensor.unsqueeze(0), size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy())
84
- image_preds = []
85
- for image, pred in zip(images, preds):
86
- image = image.resize(pred.shape[::-1])
87
- pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
88
- image_preds.append((pred * image).astype(np.uint8))
89
-
90
- return image, image_preds[0]
 
91
 
92
 
93
  examples = [[_] for _ in glob('examples/*')][:]
 
1
  import os
 
2
  import cv2
3
  import numpy as np
 
4
  import torch
 
 
5
  import gradio as gr
6
  import spaces
7
+
8
+ from glob import glob
9
+ from typing import Optional, Tuple
10
+
11
+ from PIL import Image
12
  from gradio_imageslider import ImageSlider
13
+ from transformers import AutoModelForImageSegmentation
14
+ from torchvision import transforms
15
 
16
  torch.set_float32_matmul_precision('high')
17
  torch.jit.script = lambda f: f
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
 
def array_to_pil_image(image: np.ndarray, size: Tuple[int, int] = (1024, 1024)) -> Image.Image:
    """Resize a numpy image with OpenCV and return it as an RGB PIL image.

    Args:
        image: Source image array.
        size: Target size, forwarded unchanged to ``cv2.resize`` as its
            ``dsize`` argument (OpenCV reads ``dsize`` as ``(width, height)``).

    Returns:
        The resized image wrapped in a PIL image and converted to ``'RGB'`` mode.
    """
    # Resizing is done with OpenCV (bilinear) before the PIL conversion.
    resized = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
    return Image.fromarray(resized).convert('RGB')
26
 
27
 
class ImagePreprocessor:
    """Turn an already-resized image into a normalized tensor.

    The pipeline is ``ToTensor`` followed by ``Normalize`` with the fixed
    per-channel mean/std values listed in ``__init__``.
    """

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        """Build the transform pipeline.

        Args:
            resolution: Accepted for interface compatibility; it is not used
                here because the pipeline deliberately has no resize step.
        """
        # No transforms.Resize(resolution) in this pipeline: per the original
        # note, resizing is kept consistent with the cv2.resize used in
        # training and would be redundant with the resize already applied
        # when the image is loaded/converted.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self.transform_image = transforms.Compose([to_tensor, normalize])

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the transform pipeline to ``image`` and return the result."""
        return self.transform_image(image)
39
 
 
48
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs'
49
  }
50
 
 
51
  birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
52
  birefnet.to(device)
53
  birefnet.eval()
54
 
55
 
56
  @spaces.GPU
57
+ def predict(
58
+ image: np.ndarray,
59
+ resolution: str,
60
+ weights_file: Optional[str]
61
+ ) -> Tuple[np.ndarray, np.ndarray]:
62
  global birefnet
63
  # Load BiRefNet with chosen weights
64
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
 
68
  birefnet.eval()
69
 
70
  resolution = f"{image.shape[1]}x{image.shape[0]}" if resolution == '' else resolution
 
71
  resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
72
+
73
+ image_shape = image.shape[:2]
74
+ image_pil = array_to_pil_image(image, tuple(resolution))
75
 
76
+ # Preprocess the image
77
+ image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
78
+ image_proc = image_preprocessor.proc(image_pil)
79
+ image_proc = image_proc.unsqueeze(0)
 
80
 
81
+ # Perform the prediction
82
  with torch.no_grad():
83
+ scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
84
+
85
+ if device == 'cuda':
86
+ scaled_pred_tensor = scaled_pred_tensor.cpu()
87
+
88
+ # Resize the prediction to match the original image shape
89
+ pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
90
+
91
+ # Apply the prediction mask to the original image
92
+ image_pil = image_pil.resize(pred.shape[::-1])
93
+ pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
94
+ image_pred = (pred * np.array(image_pil)).astype(np.uint8)
95
+
96
+ return image, image_pred
97
 
98
 
99
  examples = [[_] for _ in glob('examples/*')][:]