lastdefiance20 committed on
Commit d942a8d · 1 Parent(s): 3dcb152

First commit

app.py ADDED
@@ -0,0 +1,126 @@
+ import re
+ import numpy as np
+ import gradio as gr
+ import matplotlib.pyplot as plt
+
+ from io import BytesIO
+ from PIL import Image
+ from canvas import Idefics2Pipeline
+
+
+ def run_canvas(front_view, map_view, prompt):
+     pipeline = Idefics2Pipeline.from_pretrained(
+         "maum-ai/CANVAS-S"
+     )
+     messages = [
+         {"role": "system", "content": [{"type": "text", "text": prompt}]},
+         {
+             "role": "user",
+             "content": [{"type": "image"}, {"type": "image"}],
+         },
+     ]
+
+     print(front_view)
+
+     images = [front_view, map_view]
+     pred = pipeline([messages], [images], return_traj=False)
+     pred_action = re.findall(r"<ACTION_(\d+)>", pred[0])
+     pred_action = np.array(pred_action, dtype=np.int64)
+     print(pred_action)
+     pred_action_odom = pipeline.action_tokenizer.detokenize(pred_action).tolist()
+     print(pred_action_odom)
+
+     # Create a figure and axes
+     fig, axes = plt.subplots(1, 1, figsize=(8, 6))
+
+     # Scale factor for the arrow
+     scale_factor = 0.2
+
+     axes.plot(0, 0, marker="o", color="black", markersize=10)
+     axes.invert_xaxis()
+
+     for i, center in zip(pred_action, pred_action_odom):
+         x, y, yaw = center
+         axes.plot(y, x, marker="^", color="blue")
+         axes.arrow(
+             y,
+             x,
+             np.sin(yaw) * scale_factor,
+             np.cos(yaw) * scale_factor,
+             head_width=scale_factor * 0.3,
+             head_length=scale_factor * 0.3,
+             fc="k",
+             ec="k",
+         )
+         axes.text(y, x, f"{i}", fontsize=10)
+     axes.axis("equal")
+     axes.grid(True)
+
+     buf = BytesIO()
+     fig.savefig(buf, format="png")
+     buf.seek(0)  # Rewind the buffer to the beginning
+     pil_img = Image.open(buf)
+
+     return pil_img
+
+ examples = [
+     ["src/office/0a1f277a93fed629365ac5863c20c64e_frontview_6.0.png", "src/office/0a1f277a93fed629365ac5863c20c64e_map_6.0.png", """You are an indoor food-serving robot.
+
+ You must follow these driving instructions:
+ 1. You must avoid collisions.
+ 2. You should prioritize reaching the final destination.
+ 3. You should follow the Trajectory Instruction.
+ a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
+ b. You should try to evade any identifiable obstacles.
+ 4. You should maintain a constant driving speed.
+ a. Indoors, you should drive at a speed of 3-4km/h.
+ 5. You must slow down(2km/h or lower) if a human or obstacle comes within 1.5m radius.
+ a. You must slow down(2km/h or lower) in areas where a human could suddenly appear from a blind spot."""],
+     ["src/orchard/d578264e1e51cc5b8f0e496ab381cee4_frontview_79.0.png", "src/orchard/d578264e1e51cc5b8f0e496ab381cee4_map_79.0.png", """You are an outdoor speed-sprayer robot.
+
+ You must follow these driving instructions:
+ 1. You must avoid collisions.
+ 2. You should prioritize reaching the final destination.
+ 3. You should follow the Trajectory Instruction.
+ a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
+ b. You should try to evade any identifiable obstacles.
+ 4. You should maintain a constant driving speed."""],
+     ["src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_frontview_4.0.png", "src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_map_4.0.png", """You are an outdoor last mile delivery robot.
+
+ You must follow these driving instructions:
+ 1. You must avoid collisions.
+ 2. You should prioritize reaching the final destination.
+ 3. You should follow the Trajectory Instruction.
+ a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
+ b. You should try to evade any identifiable obstacles.
+ 4. You should maintain a constant driving speed.
+ 5. You must drive on the sidewalk.
+ a. If you need to cross the road, you must use the crosswalk."""],
+     ["src/road/3cfce98ab33a3dc8d43584d5a7039cf5_frontview_8.75.png", "src/road/3cfce98ab33a3dc8d43584d5a7039cf5_map_8.75.png", """You are an outdoor self-driving robot taxi.
+
+ You must follow these driving instructions:
+ 1. You must avoid collisions.
+ 2. You should prioritize reaching the final destination.
+ 3. You should follow the Trajectory Instruction.
+ a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
+ b. You should try to evade any identifiable obstacles.
+ 4. You should maintain a constant driving speed.
+ 5. You must drive on the road.
+ a. You should drive according to the left-hand-traffic law.
+ 6. You should slow down before entering intersections, speed bumps, and crosswalks."""],
+ ]
+
+ demo = gr.Interface(
+     fn=run_canvas,
+     inputs=[
+         gr.Image(label="front_view", type="pil"),
+         gr.Image(label="map_view", type="pil"),
+         gr.Textbox(label="prompt")
+     ],
+     outputs=gr.Image(label="generated waypoint"),
+     title="CANVAS Demo",
+     description="This is the demo of the CANVAS-S model from CANVAS: Commonsense-Aware Navigation System for Intuitive Human-Robot Interaction",
+     examples=examples
+ )
+
+ demo.launch()
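
For orientation, here is a minimal sketch (not part of the commit) of the action-token decoding step that run_canvas relies on. The generated string and the codebook below are made up purely to show the shapes involved; in the real app the token ids are fed to pipeline.action_tokenizer.detokenize(), whose centers come from the tokenizer bundled with maum-ai/CANVAS-S:

    import re
    import numpy as np

    # Hypothetical model output: one <ACTION_k> token is expected per predicted waypoint.
    generated = "User: ... Assistant: <ACTION_3><ACTION_17><ACTION_17><ACTION_42>"
    token_ids = np.array(re.findall(r"<ACTION_(\d+)>", generated), dtype=np.int64)
    # token_ids -> array([ 3, 17, 17, 42])

    # detokenize() maps each id to its KMeans cluster center, an (x, y, yaw) triple.
    # A fake 128-entry codebook stands in for the trained cluster centers here.
    fake_centers = np.random.RandomState(0).uniform(-0.5, 0.5, size=(128, 3))
    waypoints = fake_centers[token_ids]  # shape (4, 3), analogous to detokenize(token_ids)

Each recovered row is then drawn as a marker plus a heading arrow, exactly as in the plotting loop above.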
canvas.py ADDED
@@ -0,0 +1,295 @@
+ from pydantic import BaseModel
+ import numpy as np
+ import torch
+ import yaml  # needed by BaseModelYamlJsonMixin.from_yaml / to_yaml below
+ from PIL import Image
+ from transformers import Idefics2ForConditionalGeneration, Idefics2Processor, PreTrainedModel, ProcessorMixin
+ from typing import Optional
+ import re
+
+ import logging
+ import pickle
+ from pathlib import Path
+
+ from matplotlib import pyplot as plt
+ from sklearn.cluster import KMeans
+
+ class BaseModelYamlJsonMixin:
+     """
+     BaseModel with helper methods for loading and saving to yaml/json format.
+     """
+
+     @classmethod
+     def from_yaml(cls, path: Path):
+         with open(path, "r", encoding="utf-8") as f:
+             return cls(**yaml.safe_load(f))
+
+     def to_yaml(self: BaseModel, path: Path):
+         with open(path, "w", encoding="utf-8") as f:
+             yaml.safe_dump(self.model_dump(), f)
+
+     @classmethod
+     def from_json(cls, path: Path):
+         with open(path, "r", encoding="utf-8") as f:
+             return cls.model_validate_json(f.read())
+
+     def to_json(self: BaseModel, path: Path, indent: int = 4, *args, **kwargs):
+         with open(path, "w", encoding="utf-8") as f:
+             f.write(self.model_dump_json(indent=indent, *args, **kwargs))
+
+ class BaseModelWithYamlJsonFromTo(BaseModel, BaseModelYamlJsonMixin):
+     pass
+
+ class Idefics2TrainAdditionalConfig(BaseModel):
+     """
+     num_action_tokens (`int`, defaults to `32`):
+         Number of action tokens to add to the tokenizer vocabulary.
+     do_image_splitting (`bool`, *optional*, defaults to `False`):
+         Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
+         strategy was first introduced in https://arxiv.org/abs/2311.06607.
+     lora_config (`dict`, defaults to recommended config from https://x.com/danielhanchen/status/1791900967472140583):
+         Configuration for the LoRA model. If it is `None`, the model will not use LoRA.
+     """
+
+     # must be set to extend vocabulary of model + tokenizer
+     num_action_tokens: int = -1  # it will be overwritten by the processor_config.yml
+     # must be set to be used in pipeline
+     num_actions: int = -1  # it will be overwritten by the processor_config.yml
+
+     do_image_splitting: bool = True
+     freeze_original_vocab: bool = False
+     freeze_vision_model: bool = False
+     freeze_connector: bool = False
+     torch_dtype: str = "bfloat16"
+     lora_config: dict | None = dict(
+         r=256,
+         lora_alpha=512,
+         lora_dropout=0.1,
+         target_modules="all-linear",
+         use_rslora=True,
+         init_lora_weights="gaussian",
+         modules_to_save=["lm_head", "embed_tokens"],
+     )
+     model_name_or_path: str = "HuggingFaceM4/idefics2-8b"
+
+
+ class KMeansActionTokenizer:
+     def __init__(self, action_count: int = 128):
+         self.action_count = action_count
+         self.kmeans = KMeans(n_clusters=self.action_count, random_state=np.random.RandomState(seed=42))
+
+     @property
+     def token_count(self):
+         return self.action_count
+
+     @classmethod
+     def from_pretrained(cls, model_path: str | Path):
+         model_path = Path(model_path)
+         self = cls()
+         with open(model_path / "tokenizer.pkl", "rb") as file:
+             self.kmeans = pickle.load(file)
+         self.action_count = self.kmeans.n_clusters
+         # assert self.action_count == 32
+         return self
+
+     def save_pretrained(self, model_path: str | Path):
+         model_path = Path(model_path)
+         model_path.mkdir(exist_ok=True)
+         with open(model_path / "tokenizer.pkl", "wb") as file:
+             pickle.dump(self.kmeans, file)
+
+     def train(self, actions):
+         self.kmeans.fit(actions)
+
+     def tokenize(self, action, padding=False, max_length=-1, truncation=False):
+         # action: (K, 3) shape, adjusted delta_position and delta_yaw
+         return [i for i in self.kmeans.predict(action)]
+
+     def detokenize(self, tokens):
+         # Token check
+         check = np.asarray(tokens)
+         in_valid_range = (0 <= check) & (check < self.action_count)
+         if not in_valid_range.all():
+             logging.warning(f"Invalid tokens occur: {tokens}")
+             # If an invalid token occurs, return the stop action.
+             return np.asarray([[0.0, 0.0, 0.0] for _ in range(len(tokens))])
+         return np.asarray([self.kmeans.cluster_centers_[t] for t in tokens])
+
+     def visualize(self, figset=None):
+         if figset is None:
+             fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12, 16), dpi=300)
+         else:
+             fig, axes = figset
+         FONT = {"fontsize": 20}
+
+         axes[0].set_title("Center", fontdict=FONT)
+         axes[1].set_title("Center_Rot", fontdict=FONT)
+
+         labels = self.kmeans.labels_
+         centers = self.kmeans.cluster_centers_
+
+         # Plot each center (x, y, yaw) as a point at (x, y) with an arrow in the direction of yaw.
+         scale_factor = 0.05
+         for i, center in enumerate(centers):
+             x, y, yaw = center
+             axes[0].plot(x, y, "ro")
+             axes[0].arrow(
+                 x,
+                 y,
+                 np.cos(yaw) * scale_factor,
+                 np.sin(yaw) * scale_factor,
+                 head_width=scale_factor * 0.3,
+                 head_length=scale_factor * 0.3,
+                 fc="k",
+                 ec="k",
+             )
+             axes[0].text(x, y, f"{i}", fontsize=10)
+         axes[0].axis("equal")
+         axes[0].grid(True)
+
+         # Keep only the centers within distance 0.05 of the origin.
+         _centers = centers[np.linalg.norm(centers[:, :2], axis=1) < 0.05]
+         # print(f"action near zero: {_centers}")
+         scale_factor = 0.1
+         for center in _centers:
+             x, y, yaw = center
+             axes[1].plot(x, y, "ro")
+             axes[1].arrow(
+                 x,
+                 y,
+                 np.cos(yaw) * scale_factor,
+                 np.sin(yaw) * scale_factor,
+                 head_width=scale_factor * 0.3,
+                 head_length=scale_factor * 0.3,
+                 fc="k",
+                 ec="k",
+             )
+         axes[1].axis("equal")
+         axes[1].grid(True)
+
+         return fig, axes
+
+
+ class Idefics2PipelineConfig(BaseModelWithYamlJsonFromTo):
+     pipeline_class: str = "Idefics2Pipeline"
+     train_additional_cfg: Idefics2TrainAdditionalConfig
+
+
+ class Idefics2Pipeline:
+     def __init__(
+         self,
+         model: PreTrainedModel,
+         processor: ProcessorMixin,
+         action_tokenizer: KMeansActionTokenizer,
+         config: Idefics2PipelineConfig,
+     ):
+         self.model = model
+         self.processor = processor
+         self.action_tokenizer = action_tokenizer
+         self.config = config
+
+     def save_pretrained(
+         self,
+         save_directory: str,
+     ):
+         if not isinstance(save_directory, Path):
+             save_directory = Path(save_directory)
+         self.model.save_pretrained(save_directory)
+         self.processor.save_pretrained(save_directory)
+         self.action_tokenizer.save_pretrained(save_directory)
+         self.config.to_json(f"{save_directory}/pipeline_config.json")
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: str):
+         if not isinstance(pretrained_model_name_or_path, Path):
+             pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+
+         config = Idefics2PipelineConfig.model_validate_json(
+             (pretrained_model_name_or_path / "pipeline_config.json").read_text()
+         )
+         model = Idefics2ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
+         processor = Idefics2Processor.from_pretrained(pretrained_model_name_or_path)
+         model.eval()
+         action_tokenizer = KMeansActionTokenizer.from_pretrained(pretrained_model_name_or_path)
+         return cls(model, processor, action_tokenizer, config)
+
+     def to(self, device):
+         return self.model.to(device)
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         examples: list[dict],
+         return_traj: Optional[bool] = False,
+     ):
+         """
+         call model with examples
+
+         Args:
+             examples: list of example, [B, example]
+             return_traj: return trajectory if True
+         """
+
+         raise NotImplementedError("Not implemented yet")
+
+         # same as idefics2 data collator
+         texts = []
+         images = []
+         for example in examples:
+             image = example["images"]
+             messages = example["messages"]
+             text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
+             texts.append(text.strip())
+             images.append(image)
+         inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+         generate_ids = self.model.generate(**inputs, max_new_tokens=self.config.num_actions)
+         generated_text = self.processor.batch_decode(generate_ids, skip_special_tokens=True)
+
+         if return_traj:
+             return self.action_tokenizer.detokenize(generated_text)
+         else:
+             return generated_text
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         message_list: list[list[dict]],
+         images_list: list[list[Image.Image]],
+         return_traj: Optional[bool] = False,
+     ):
+         """
+         call model with message and images
+
+         Args:
+             message_list: list of messages, [B, messages]
+             images_list: list of images, [B, images]
+             return_traj: return trajectory if True
+         """
+
+         # we don't use batch inference for run model worker
+         if len(message_list) != 1:
+             raise ValueError("No batch api call allowed for Idefics2Pipeline")
+
+         message = message_list[0]
+         images = images_list[0]
+         prompt = self.processor.apply_chat_template(message, add_generation_prompt=True)
+         prompt = prompt.replace("<end_of_utterance>", "")  # str.replace returns a new string; assign it
+         # add space to match the training data
+         prompt = prompt + " "
+         inputs = self.processor(text=prompt, images=images, return_tensors="pt", padding=True)
+
+         device = self.model.device
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         generate_ids = self.model.generate(
+             **inputs, max_new_tokens=self.config.train_additional_cfg.num_actions, top_k=1
+         )
+         generated_texts = self.processor.batch_decode(generate_ids, skip_special_tokens=True)
+         if return_traj:
+             pred_action = re.findall(r"<ACTION_(\d+)>", generated_texts[0])
+             # pred_action = pred_action if len(pred_action) == self.config.num_actions else [-1] * self.config.num_actions
+             pred_action = np.array(pred_action, dtype=np.int64)
+             return self.action_tokenizer.detokenize(pred_action).tolist()
+         else:
+             return generated_texts
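
As a quick reference, here is a minimal sketch (not part of the commit) of how KMeansActionTokenizer discretizes continuous motion into tokens and back. The synthetic (x, y, yaw) deltas are illustrative only; the released tokenizer.pkl is trained on real trajectory data instead:

    import numpy as np
    from canvas import KMeansActionTokenizer

    rng = np.random.RandomState(0)
    actions = rng.uniform(-0.5, 0.5, size=(10000, 3))  # synthetic (x, y, yaw) deltas

    tokenizer = KMeansActionTokenizer(action_count=128)
    tokenizer.train(actions)                   # fits the 128-cluster KMeans codebook

    tokens = tokenizer.tokenize(actions[:5])   # five integer ids in [0, 128)
    recovered = tokenizer.detokenize(tokens)   # (5, 3) cluster centers approximating the inputs

    # Out-of-range ids fall back to the stop action (all zeros), per detokenize().
    print(tokenizer.detokenize([999]))

In the pipeline, these ids correspond one-to-one with the <ACTION_k> tokens that the model generates.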
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ transformers==4.46.1
+ datasets==3.1.0
+ pillow==10.4.0
+ numpy==2.1.3
+ torch==2.4.0
+ pydantic==2.9.2
+ scikit-learn==1.5.2
+ matplotlib==3.9.3
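
Note: app.py also imports gradio, which is not pinned here; a Hugging Face Gradio Space provides it via the SDK setting, but a local run would need it installed alongside these pins, e.g. (assuming a fresh virtual environment):

    pip install -r requirements.txt gradio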
src/office/0a1f277a93fed629365ac5863c20c64e_frontview_6.0.png ADDED
src/office/0a1f277a93fed629365ac5863c20c64e_map_6.0.png ADDED
src/orchard/d578264e1e51cc5b8f0e496ab381cee4_frontview_79.0.png ADDED
src/orchard/d578264e1e51cc5b8f0e496ab381cee4_map_79.0.png ADDED
src/road/3cfce98ab33a3dc8d43584d5a7039cf5_frontview_8.75.png ADDED
src/road/3cfce98ab33a3dc8d43584d5a7039cf5_map_8.75.png ADDED
src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_frontview_4.0.png ADDED
src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_map_4.0.png ADDED