hlky (HF staff) committed
Commit bfe4ff1 · verified · 1 Parent(s): 6691b0e

new handler

Files changed (1):
  1. handler.py +46 -60
handler.py CHANGED
@@ -1,86 +1,72 @@
-from typing import Dict, List, Any
+from typing import cast, Union
+
+import PIL.Image
 import torch
-from base64 import b64decode
-from huggingface_hub import model_info
+
 from diffusers import AutoencoderKL
 from diffusers.image_processor import VaeImageProcessor
 
+
 class EndpointHandler:
     def __init__(self, path=""):
         self.device = "cuda"
         self.dtype = torch.float16
-        self.vae = AutoencoderKL.from_pretrained(path, torch_dtype=self.dtype).to(self.device, self.dtype).eval()
+        self.vae = cast(AutoencoderKL, AutoencoderKL.from_pretrained(path, torch_dtype=self.dtype).to(self.device, self.dtype).eval())
 
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     @torch.no_grad()
-    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
+    def __call__(self, data) -> Union[torch.Tensor, PIL.Image.Image]:
         """
         Args:
            data (:obj:):
               includes the input data and the parameters for the inference.
        """
-        tensor = data["inputs"]
-        tensor = b64decode(tensor.encode("utf-8"))
-        parameters = data.get("parameters", {})
-        if "shape" not in parameters:
-            raise ValueError("Expected `shape` in parameters.")
-        if "dtype" not in parameters:
-            raise ValueError("Expected `dtype` in parameters.")
-
-        DTYPE_MAP = {
-            "float16": torch.float16,
-            "float32": torch.float32,
-            "bfloat16": torch.bfloat16,
-        }
+        tensor = cast(torch.Tensor, data["inputs"])
+        parameters = cast(dict, data.get("parameters", {}))
+        do_scaling = cast(bool, parameters.get("do_scaling", True))
+        output_type = cast(str, parameters.get("output_type", "pil"))
+        partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
+        if partial_postprocess and output_type != "pt":
+            output_type = "pt"
 
-        shape = parameters.get("shape")
-        dtype = DTYPE_MAP.get(parameters.get("dtype"))
-        tensor = torch.frombuffer(bytearray(tensor), dtype=dtype).reshape(shape)
+        tensor = tensor.to(self.device, self.dtype)
 
-        needs_upcasting = (
-            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
-        )
-        if needs_upcasting:
-            self.vae = self.vae.to(torch.float32)
-            tensor = tensor.to(self.device, torch.float32)
-        else:
-            tensor = tensor.to(self.device, self.dtype)
-
-        # unscale/denormalize the latents
-        # denormalize with the mean and std if available and not None
-        has_latents_mean = (
-            hasattr(self.vae.config, "latents_mean")
-            and self.vae.config.latents_mean is not None
-        )
-        has_latents_std = (
-            hasattr(self.vae.config, "latents_std")
-            and self.vae.config.latents_std is not None
-        )
-        if has_latents_mean and has_latents_std:
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, 4, 1, 1)
-                .to(tensor.device, tensor.dtype)
-            )
-            latents_std = (
-                torch.tensor(self.vae.config.latents_std)
-                .view(1, 4, 1, 1)
-                .to(tensor.device, tensor.dtype)
+        if do_scaling:
+            has_latents_mean = (
+                hasattr(self.vae.config, "latents_mean")
+                and self.vae.config.latents_mean is not None
             )
-            tensor = (
-                tensor * latents_std / self.vae.config.scaling_factor + latents_mean
+            has_latents_std = (
+                hasattr(self.vae.config, "latents_std")
+                and self.vae.config.latents_std is not None
             )
-        else:
-            tensor = tensor / self.vae.config.scaling_factor
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean)
+                    .view(1, 4, 1, 1)
+                    .to(tensor.device, tensor.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std)
+                    .view(1, 4, 1, 1)
+                    .to(tensor.device, tensor.dtype)
+                )
+                tensor = (
+                    tensor * latents_std / self.vae.config.scaling_factor + latents_mean
+                )
+            else:
+                tensor = tensor / self.vae.config.scaling_factor
 
         with torch.no_grad():
-            image = self.vae.decode(tensor, return_dict=False)[0]
-
-        if needs_upcasting:
-            self.vae.to(dtype=torch.float16)
+            image = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
 
-        image = self.image_processor.postprocess(image, output_type="pil")
+        if partial_postprocess:
+            image = (image * 0.5 + 0.5).clamp(0, 1)
+            image = image.permute(0, 2, 3, 1).contiguous().float()
+            image = (image * 255).round().to(torch.uint8)
+        elif output_type == "pil":
+            image = cast(PIL.Image.Image, self.image_processor.postprocess(image, output_type="pil")[0])
 
-        return image[0]
+        return image
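
The diff changes the endpoint's contract: `inputs` now arrives as a latent tensor rather than a base64-encoded buffer, and the `do_scaling`, `output_type`, and `partial_postprocess` parameters control decoding. A minimal, hypothetical smoke test of the new handler follows; it is not part of the commit, and the VAE repo id and the 4-channel 128x128 latent shape are assumptions:

# Hypothetical usage sketch, not part of the commit.
import torch

from handler import EndpointHandler

# Assumed SDXL-style VAE checkpoint; any repo loadable with
# AutoencoderKL.from_pretrained would do.
handler = EndpointHandler(path="madebyollin/sdxl-vae-fp16-fix")

# Assumed latent shape: batch 1, 4 channels, 128x128, i.e. a 1024x1024
# image at the usual 8x VAE scale factor.
latents = torch.randn(1, 4, 128, 128, dtype=torch.float16)

# Default parameters: latents are unscaled by scaling_factor (or
# denormalized with latents_mean/latents_std when the config has them)
# and decoded to a single PIL image.
pil_image = handler({"inputs": latents})

# partial_postprocess=True instead returns a uint8 NHWC tensor,
# leaving the final PIL conversion to the caller.
uint8_batch = handler({"inputs": latents, "parameters": {"partial_postprocess": True}})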