Oysiyl committed
Commit 6889e88
Parent(s): 6760956

Create handler.py

Files changed (1): handler.py (+80, -0)
handler.py ADDED
@@ -0,0 +1,80 @@
+ from typing import Dict, Any
+ import base64
+ from io import BytesIO
+
+ from PIL import Image
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ # set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ if device.type != "cuda":
+     raise ValueError("need to run on GPU")
+ # set mixed precision dtype (bfloat16 on Ampere or newer GPUs, float16 otherwise)
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # Load the Stable Diffusion text-to-image pipeline
+         self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
+         self.pipe = AutoPipelineForText2Image.from_pretrained(
+             self.stable_diffusion_id,
+             torch_dtype=dtype,
+             safety_checker=None,
+         )
+         # Apply the LoRA weights and enable memory-efficient attention
+         self.pipe.load_lora_weights("pytorch_lora_weights.safetensors")
+         self.pipe.enable_xformers_memory_efficient_attention()
+         self.pipe = self.pipe.to(device)
+         self.seed = 42
+         # Define generator with seed for reproducible outputs
+         self.generator = torch.Generator(device="cpu").manual_seed(self.seed)
+
+     def __call__(self, data: Dict[str, Any]) -> Any:
+         """
+         :param data: A dictionary containing `inputs` (the prompt) and optional
+                      `seed`, `num_inference_steps` and `guidance_scale` fields.
+         :return: The first generated image as a PIL image, or an error dictionary.
+         """
+         prompt = data.pop("inputs", None)
+         seed = data.pop("seed", 42)
+
+         # Check that a prompt was provided
+         if prompt is None:
+             return {"error": "Please provide a prompt."}
+
+         # Re-seed the generator if the seed changed
+         if seed is not None and seed != self.seed:
+             print(f"changing seed from {self.seed} to {seed}")
+             self.seed = seed
+             self.generator = torch.Generator(device="cpu").manual_seed(self.seed)
+
+         # hyperparameters
+         num_inference_steps = data.pop("num_inference_steps", 50)
+         guidance_scale = data.pop("guidance_scale", 7.5)
+
+         # run inference pipeline
+         out = self.pipe(
+             prompt=prompt,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             num_images_per_prompt=1,
+             generator=self.generator,
+         )
+
+         # return the first generated PIL image
+         return out.images[0]
+
+     # helper to decode a base64-encoded input image (not used by the text-to-image path)
+     def decode_base64_image(self, image_string):
+         base64_image = base64.b64decode(image_string)
+         buffer = BytesIO(base64_image)
+         image = Image.open(buffer)
+         return image
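For a quick local smoke test before deploying the handler as an Inference Endpoint, a minimal sketch along these lines can be used. It assumes a CUDA GPU and the pytorch_lora_weights.safetensors file next to handler.py; the prompt and output filename are only illustrative, and the payload keys mirror what __call__ pops from data.

# Hypothetical local smoke test for EndpointHandler (not part of the commit).
# Assumes a CUDA GPU and pytorch_lora_weights.safetensors in the working directory.
from handler import EndpointHandler

handler = EndpointHandler(path=".")

# Payload keys mirror what __call__ pops from `data`; only "inputs" is required.
payload = {
    "inputs": "a photo of an astronaut riding a horse",  # example prompt
    "seed": 123,
    "num_inference_steps": 30,
    "guidance_scale": 7.5,
}

image = handler(payload)  # returns a PIL.Image.Image
image.save("sample.png")  # arbitrary output filename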