sohojoe commited on
Commit
334dcac
·
1 Parent(s): 0b8f387

experimental ray

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ *.pyc
experimental/clip_app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File name: model.py
2
+ import json
3
+ import os
4
+ import numpy as np
5
+ import torch
6
+ from starlette.requests import Request
7
+ from PIL import Image
8
+ import ray
9
+ from ray import serve
10
+ from clip_retrieval.load_clip import load_clip, get_tokenizer
11
+ # from clip_retrieval.clip_client import ClipClient, Modality
12
+
13
+ # @serve.deployment(num_replicas=2, ray_actor_options={"num_cpus": 0.2, "num_gpus": 0.2})
14
+ # @serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 0.2, "num_gpus": 0.0})
15
+ @serve.deployment(num_replicas=10, ray_actor_options={"num_cpus": .2, "num_gpus": 0.0})
16
+ class CLIPTransform:
17
+ def __init__(self):
18
+ # os.environ["OMP_NUM_THREADS"] = "20"
19
+ # torch.set_num_threads(20)
20
+ # Load model
21
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
22
+ self._clip_model="ViT-L/14"
23
+ self._clip_model_id ="laion5B-L-14"
24
+ self.model, self.preprocess = load_clip(self._clip_model, use_jit=True, device=self.device)
25
+ self.tokenizer = get_tokenizer(self._clip_model)
26
+
27
+ print ("using device", self.device)
28
+
29
+ def text_to_embeddings(self, prompt):
30
+ text = self.tokenizer([prompt]).to(self.device)
31
+ with torch.no_grad():
32
+ prompt_embededdings = self.model.encode_text(text)
33
+ prompt_embededdings /= prompt_embededdings.norm(dim=-1, keepdim=True)
34
+ return(prompt_embededdings)
35
+
36
+ def image_to_embeddings(self, input_im):
37
+ input_im = Image.fromarray(input_im)
38
+ prepro = self.preprocess(input_im).unsqueeze(0).to(self.device)
39
+ with torch.no_grad():
40
+ image_embeddings = self.model.encode_image(prepro)
41
+ image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
42
+ return(image_embeddings)
43
+
44
+ def preprocessed_image_to_emdeddings(self, prepro):
45
+ with torch.no_grad():
46
+ image_embeddings = self.model.encode_image(prepro)
47
+ image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
48
+ return(image_embeddings)
49
+
50
+ async def __call__(self, http_request: Request) -> str:
51
+ request = await http_request.json()
52
+ # print(type(request))
53
+ # print(str(request))
54
+ # switch based if we are using text or image
55
+ embeddings = None
56
+ if "text" in request:
57
+ prompt = request["text"]
58
+ embeddings = self.text_to_embeddings(prompt)
59
+ elif "image" in request:
60
+ image_url = request["image_url"]
61
+ # download image from url
62
+ import requests
63
+ from io import BytesIO
64
+ input_image = Image.open(BytesIO(image_url))
65
+ input_image = input_image.convert('RGB')
66
+ input_image = np.array(input_image)
67
+ embeddings = self.image_to_embeddings(input_image)
68
+ elif "preprocessed_image" in request:
69
+ prepro = request["preprocessed_image"]
70
+ # create torch tensor on the device
71
+ prepro = torch.tensor(prepro).to(self.device)
72
+ embeddings = self.preprocessed_image_to_emdeddings(prepro)
73
+ else:
74
+ raise Exception("Invalid request")
75
+ return embeddings.cpu().numpy().tolist()
76
+
77
+ deployment_graph = CLIPTransform.bind()
experimental/clip_app_client.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File name: graph_client.py
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import json
4
+ import requests
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+ import time
7
+
8
+ # english_text = (
9
+ # "It was the best of times, it was the worst of times, it was the age "
10
+ # "of wisdom, it was the age of foolishness, it was the epoch of belief"
11
+ # )
12
+ # response = requests.post("http://127.0.0.1:8000/", json=english_text)
13
+ # french_text = response.text
14
+
15
+ # print(french_text)
16
+
17
+ test_image_url = "https://static.wixstatic.com/media/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg/v1/fill/w_454,h_333,fp_0.50_0.50,q_90/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg"
18
+ english_text = (
19
+ "It was the best of times, it was the worst of times, it was the age "
20
+ "of wisdom, it was the age of foolishness, it was the epoch of belief"
21
+ )
22
+
23
+
24
+ def send_text_request(number):
25
+ json = {"text": english_text}
26
+ response = requests.post("http://127.0.0.1:8000/", json=json)
27
+ embeddings = response.text
28
+ return number, embeddings
29
+
30
+ def process_text(numbers, max_workers=10):
31
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
32
+ futures = [executor.submit(send_text_request, number) for number in numbers]
33
+ for future in as_completed(futures):
34
+ n_result, result = future.result()
35
+ result = json.loads(result)
36
+ print (f"{n_result} : {len(result[0])}")
37
+
38
+ # def process_text(numbers, max_workers=10):
39
+ # for n in numbers:
40
+ # n_result, result = send_text_request(n)
41
+ # result = json.loads(result)
42
+ # print (f"{n_result} : {len(result[0])}")
43
+
44
+
45
+ if __name__ == "__main__":
46
+ # n_calls = 100000
47
+ n_calls = 1000
48
+ numbers = list(range(n_calls))
49
+ start_time = time.monotonic()
50
+ process_text(numbers)
51
+ end_time = time.monotonic()
52
+ total_time = end_time - start_time
53
+ avg_time_ms = total_time / n_calls * 1000
54
+ calls_per_sec = n_calls / total_time
55
+ print(f"Average time taken: {avg_time_ms:.2f} ms")
56
+ print(f"Number of calls per second: {calls_per_sec:.2f}")
57
+
experimental/fast_inference.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ import ray
6
+ from ray import serve
7
+ from clip_retrieval.load_clip import load_clip, get_tokenizer
8
+ # from clip_retrieval.clip_client import ClipClient, Modality
9
+
10
+
11
+
12
+ class CLIPModel:
13
+ def __init__(self):
14
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
15
+ self._test_image_url = "https://static.wixstatic.com/media/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg/v1/fill/w_454,h_333,fp_0.50_0.50,q_90/4d6b49_42b9435ce1104008b1b5f7a3c9bfcd69~mv2.jpg"
16
+ self._clip_model="ViT-L/14"
17
+ self._clip_model_id ="laion5B-L-14"
18
+
19
+ self.model, self.preprocess = load_clip(self._clip_model, use_jit=True, device=self.device)
20
+ self.tokenizer = get_tokenizer(self._clip_model)
21
+
22
+ print ("using device", self.device)
23
+
24
+
25
+ def test_to_embeddings(self, prompt):
26
+ text = self.tokenizer([prompt]).to(self.device)
27
+ with torch.no_grad():
28
+ prompt_embededdings = self.model.encode_text(text)
29
+ prompt_embededdings /= prompt_embededdings.norm(dim=-1, keepdim=True)
30
+ return(prompt_embededdings)
31
+
32
+ def image_to_embeddings(self, input_im):
33
+ input_im = Image.fromarray(input_im)
34
+ prepro = self.preprocess(input_im).unsqueeze(0).to(self.device)
35
+ with torch.no_grad():
36
+ image_embeddings = self.model.encode_image(prepro)
37
+ image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
38
+ return(image_embeddings)
39
+
40
+ def preprocessed_image_to_emdeddings(self, prepro):
41
+ with torch.no_grad():
42
+ image_embeddings = self.model.encode_image(prepro)
43
+ image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
44
+ return(image_embeddings)
45
+
46
+ # simple regression test
47
+ def regression_test(self):
48
+ text_embeddings = self.test_to_embeddings("Howdy!")
49
+ print("text embeddings", text_embeddings)
50
+
51
+ # download image from url
52
+ import requests
53
+ from io import BytesIO
54
+ response = requests.get(self._test_image_url)
55
+ input_image = Image.open(BytesIO(response.content))
56
+ input_image = input_image.convert('RGB')
57
+ # convert image to numpy array
58
+ input_image = np.array(input_image)
59
+ image_embeddings = self.image_to_embeddings(input_image)
60
+ print("image embeddings", image_embeddings)
61
+
62
+ input_im = Image.fromarray(input_image)
63
+ prepro = self.preprocess(input_im).unsqueeze(0).to(self.device)
64
+ image_embeddings = self.preprocessed_image_to_emdeddings(prepro)
65
+ print("image embeddings", image_embeddings)
66
+
67
+ # regression test
68
+ test_instance = CLIPModel()
69
+ test_instance.regression_test()
70
+
71
+ ray.init()
72
+ serve.start()
73
+ # Register the model with Ray Serve
74
+ serve.create_backend("clip_model", CLIPModel)
75
+ serve.create_endpoint("clip_model", backend="clip_model", route="/clip_model")
76
+
77
+
78
+ # You can now call the endpoint with your input
79
+ import requests
80
+
81
+ input_prompt = "Howdy!"
82
+ response = requests.get("http://localhost:8000/clip_model", json={"prompt": input_prompt})
83
+ print(response.json())
84
+
85
+
local_test.py → experimental/local_test.py RENAMED
File without changes