lastdefiance20 committed
Commit · d942a8d
1 Parent(s): 3dcb152
First commit
- app.py +126 -0
- canvas.py +295 -0
- requirements.txt +8 -0
- src/office/0a1f277a93fed629365ac5863c20c64e_frontview_6.0.png +0 -0
- src/office/0a1f277a93fed629365ac5863c20c64e_map_6.0.png +0 -0
- src/orchard/d578264e1e51cc5b8f0e496ab381cee4_frontview_79.0.png +0 -0
- src/orchard/d578264e1e51cc5b8f0e496ab381cee4_map_79.0.png +0 -0
- src/road/3cfce98ab33a3dc8d43584d5a7039cf5_frontview_8.75.png +0 -0
- src/road/3cfce98ab33a3dc8d43584d5a7039cf5_map_8.75.png +0 -0
- src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_frontview_4.0.png +0 -0
- src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_map_4.0.png +0 -0
app.py
ADDED
@@ -0,0 +1,126 @@
import re
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

from io import BytesIO
from PIL import Image
from canvas import Idefics2Pipeline


def run_canvas(front_view, map_view, prompt):
    pipeline = Idefics2Pipeline.from_pretrained(
        "maum-ai/CANVAS-S"
    )
    messages = [
        {"role": "system", "content": [{"type": "text", "text": prompt}]},
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "image"}],
        },
    ]

    print(front_view)

    images = [front_view, map_view]
    pred = pipeline([messages], [images], return_traj=False)
    pred_action = re.findall(r"<ACTION_(\d+)>", pred[0])
    pred_action = np.array(pred_action, dtype=np.int64)
    print(pred_action)
    pred_action_odom = pipeline.action_tokenizer.detokenize(pred_action).tolist()
    print(pred_action_odom)

    # Create a figure and axes
    fig, axes = plt.subplots(1, 1, figsize=(8, 6))

    # Scale factor for the arrow
    scale_factor = 0.2

    axes.plot(0, 0, marker="o", color="black", markersize=10)
    axes.invert_xaxis()

    for i, center in zip(pred_action, pred_action_odom):
        x, y, yaw = center
        axes.plot(y, x, marker="^", color="blue")
        axes.arrow(
            y,
            x,
            np.sin(yaw) * scale_factor,
            np.cos(yaw) * scale_factor,
            head_width=scale_factor * 0.3,
            head_length=scale_factor * 0.3,
            fc="k",
            ec="k",
        )
        axes.text(y, x, f"{i}", fontsize=10)
    axes.axis("equal")
    axes.grid(True)

    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)  # Rewind the buffer to the beginning
    pil_img = Image.open(buf)
    plt.close(fig)  # release the figure so repeated calls don't leak memory

    return pil_img


examples = [
    ["src/office/0a1f277a93fed629365ac5863c20c64e_frontview_6.0.png", "src/office/0a1f277a93fed629365ac5863c20c64e_map_6.0.png", """You are an indoor food-serving robot.

You must follow these driving instructions:
1. You must avoid collisions.
2. You should prioritize reaching the final destination.
3. You should follow the Trajectory Instruction.
a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
b. You should try to evade any identifiable obstacles.
4. You should maintain a constant driving speed.
a. Indoors, you should drive at a speed of 3-4km/h.
5. You must slow down(2km/h or lower) if a human or obstacle comes within 1.5m radius.
a. You must slow down(2km/h or lower) in areas where a human could suddenly appear from a blind spot."""],
    ["src/orchard/d578264e1e51cc5b8f0e496ab381cee4_frontview_79.0.png", "src/orchard/d578264e1e51cc5b8f0e496ab381cee4_map_79.0.png", """You are an outdoor speed-sprayer robot.

You must follow these driving instructions:
1. You must avoid collisions.
2. You should prioritize reaching the final destination.
3. You should follow the Trajectory Instruction.
a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
b. You should try to evade any identifiable obstacles.
4. You should maintain a constant driving speed."""],
    ["src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_frontview_4.0.png", "src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_map_4.0.png", """You are an outdoor last mile delivery robot.

You must follow these driving instructions:
1. You must avoid collisions.
2. You should prioritize reaching the final destination.
3. You should follow the Trajectory Instruction.
a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
b. You should try to evade any identifiable obstacles.
4. You should maintain a constant driving speed.
5. You must drive on the sidewalk.
a. If you need to cross the road, you must use the crosswalk."""],
    ["src/road/3cfce98ab33a3dc8d43584d5a7039cf5_frontview_8.75.png", "src/road/3cfce98ab33a3dc8d43584d5a7039cf5_map_8.75.png", """You are an outdoor self-driving robot taxi.

You must follow these driving instructions:
1. You must avoid collisions.
2. You should prioritize reaching the final destination.
3. You should follow the Trajectory Instruction.
a. If the Trajectory Instruction cannot be followed due to any obstacles, you should deviate to bypass the obstacle.
b. You should try to evade any identifiable obstacles.
4. You should maintain a constant driving speed.
5. You must drive on the road.
a. You should drive according to the left-hand-traffic law.
6. You should slow down before entering intersections, speed bumps, and crosswalks."""],
]

demo = gr.Interface(
    fn=run_canvas,
    inputs=[
        gr.Image(label="front_view", type="pil"),
        gr.Image(label="map_view", type="pil"),
        gr.Textbox(label="prompt"),
    ],
    outputs=gr.Image(label="generated waypoint"),
    title="CANVAS Demo",
    description="This is the demo of the CANVAS-S model from CANVAS: Commonsense-Aware Navigation System for Intuitive Human-Robot Interaction",
    examples=examples,
)

demo.launch()
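One practical note on app.py: run_canvas calls Idefics2Pipeline.from_pretrained("maum-ai/CANVAS-S") inside the handler, so every demo request re-loads the full model. A minimal sketch of loading once per process instead (the get_pipeline helper is illustrative, not part of this commit):

from functools import lru_cache

from canvas import Idefics2Pipeline


@lru_cache(maxsize=1)
def get_pipeline():
    # First call loads the model; later calls return the cached instance.
    return Idefics2Pipeline.from_pretrained("maum-ai/CANVAS-S")


def run_canvas(front_view, map_view, prompt):
    pipeline = get_pipeline()  # replaces the per-call from_pretrained above
    ...  # rest of the handler unchanged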
canvas.py
ADDED
@@ -0,0 +1,295 @@
from pydantic import BaseModel
import numpy as np
import torch
import yaml  # needed by the yaml helpers below; missing from the original imports
from PIL import Image
from transformers import Idefics2ForConditionalGeneration, Idefics2Processor, PreTrainedModel, ProcessorMixin
from typing import Optional
import re

import logging
import pickle
from pathlib import Path

from matplotlib import pyplot as plt
from sklearn.cluster import KMeans


class BaseModelYamlJsonMixin:
    """
    BaseModel with helper methods for loading and saving to yaml/json format.
    """

    @classmethod
    def from_yaml(cls, path: Path):
        with open(path, "r", encoding="utf-8") as f:
            return cls(**yaml.safe_load(f))

    def to_yaml(self: BaseModel, path: Path):
        with open(path, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.model_dump(), f)

    @classmethod
    def from_json(cls, path: Path):
        with open(path, "r", encoding="utf-8") as f:
            return cls.model_validate_json(f.read())

    def to_json(self: BaseModel, path: Path, indent: int = 4, *args, **kwargs):
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.model_dump_json(indent=indent, *args, **kwargs))


class BaseModelWithYamlJsonFromTo(BaseModel, BaseModelYamlJsonMixin):
    pass


class Idefics2TrainAdditionalConfig(BaseModel):
    """
    num_action_tokens (`int`, defaults to `32`):
        Number of action tokens to add to the tokenizer vocabulary.
    do_image_splitting (`bool`, *optional*, defaults to `False`):
        Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
        strategy was first introduced in https://arxiv.org/abs/2311.06607.
    lora_config (`dict`, defaults to the recommended config from https://x.com/danielhanchen/status/1791900967472140583):
        Configuration for the LoRA model. If it is `None`, the model will not use LoRA.
    """

    # must be set to extend the vocabulary of the model + tokenizer
    num_action_tokens: int = -1  # overwritten by the processor_config.yml
    # must be set to be used in the pipeline
    num_actions: int = -1  # overwritten by the processor_config.yml

    do_image_splitting: bool = True
    freeze_original_vocab: bool = False
    freeze_vision_model: bool = False
    freeze_connector: bool = False
    torch_dtype: str = "bfloat16"
    lora_config: dict | None = dict(
        r=256,
        lora_alpha=512,
        lora_dropout=0.1,
        target_modules="all-linear",
        use_rslora=True,
        init_lora_weights="gaussian",
        modules_to_save=["lm_head", "embed_tokens"],
    )
    model_name_or_path: str = "HuggingFaceM4/idefics2-8b"


class KMeansActionTokenizer:
    def __init__(self, action_count: int = 128):
        self.action_count = action_count
        self.kmeans = KMeans(n_clusters=self.action_count, random_state=np.random.RandomState(seed=42))

    @property
    def token_count(self):
        return self.action_count

    @classmethod
    def from_pretrained(cls, model_path: str | Path):
        model_path = Path(model_path)
        self = cls()
        with open(model_path / "tokenizer.pkl", "rb") as file:
            self.kmeans = pickle.load(file)
        self.action_count = self.kmeans.n_clusters
        return self

    def save_pretrained(self, model_path: str | Path):
        model_path = Path(model_path)
        model_path.mkdir(exist_ok=True)
        with open(model_path / "tokenizer.pkl", "wb") as file:
            pickle.dump(self.kmeans, file)

    def train(self, actions):
        self.kmeans.fit(actions)

    def tokenize(self, action, padding=False, max_length=-1, truncation=False):
        # action: (K, 3) array of adjusted delta_position and delta_yaw
        return [i for i in self.kmeans.predict(action)]

    def detokenize(self, tokens):
        # Token check: validate the token ids before decoding
        check = np.asarray(tokens)
        in_valid_range = (0 <= check) & (check < self.action_count)
        if not in_valid_range.all():
            logging.warning(f"Invalid tokens occur: {tokens}")
            # If an error occurs, return the stop action.
            return np.asarray([[0.0, 0.0, 0.0] for _ in range(len(tokens))])
        return np.asarray([self.kmeans.cluster_centers_[t] for t in tokens])

    def visualize(self, figset=None):
        if figset is None:
            fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12, 16), dpi=300)
        else:
            fig, axes = figset
        FONT = {"fontsize": 20}

        axes[0].set_title("Center", fontdict=FONT)
        axes[1].set_title("Center_Rot", fontdict=FONT)

        labels = self.kmeans.labels_
        centers = self.kmeans.cluster_centers_

        # Plot each cluster center (x, y, yaw) as a point plus an arrow pointing in the yaw direction.
        scale_factor = 0.05
        for i, center in enumerate(centers):
            x, y, yaw = center
            axes[0].plot(x, y, "ro")
            axes[0].arrow(
                x,
                y,
                np.cos(yaw) * scale_factor,
                np.sin(yaw) * scale_factor,
                head_width=scale_factor * 0.3,
                head_length=scale_factor * 0.3,
                fc="k",
                ec="k",
            )
            axes[0].text(x, y, f"{i}", fontsize=10)
        axes[0].axis("equal")
        axes[0].grid(True)

        # Re-plot only the centers within distance 0.05 of the origin, at a larger arrow scale.
        _centers = centers[np.linalg.norm(centers[:, :2], axis=1) < 0.05]
        scale_factor = 0.1
        for center in _centers:
            x, y, yaw = center
            axes[1].plot(x, y, "ro")
            axes[1].arrow(
                x,
                y,
                np.cos(yaw) * scale_factor,
                np.sin(yaw) * scale_factor,
                head_width=scale_factor * 0.3,
                head_length=scale_factor * 0.3,
                fc="k",
                ec="k",
            )
        axes[1].axis("equal")
        axes[1].grid(True)

        return fig, axes


class Idefics2PipelineConfig(BaseModelWithYamlJsonFromTo):
    pipeline_class: str = "Idefics2Pipeline"
    train_additional_cfg: Idefics2TrainAdditionalConfig


class Idefics2Pipeline:
    def __init__(
        self,
        model: PreTrainedModel,
        processor: ProcessorMixin,
        action_tokenizer: KMeansActionTokenizer,
        config: Idefics2PipelineConfig,
    ):
        self.model = model
        self.processor = processor
        self.action_tokenizer = action_tokenizer
        self.config = config

    def save_pretrained(
        self,
        save_directory: str,
    ):
        if not isinstance(save_directory, Path):
            save_directory = Path(save_directory)
        self.model.save_pretrained(save_directory)
        self.processor.save_pretrained(save_directory)
        self.action_tokenizer.save_pretrained(save_directory)
        self.config.to_json(f"{save_directory}/pipeline_config.json")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str):
        if not isinstance(pretrained_model_name_or_path, Path):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        config = Idefics2PipelineConfig.model_validate_json(
            (pretrained_model_name_or_path / "pipeline_config.json").read_text()
        )
        model = Idefics2ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
        processor = Idefics2Processor.from_pretrained(pretrained_model_name_or_path)
        model.eval()
        action_tokenizer = KMeansActionTokenizer.from_pretrained(pretrained_model_name_or_path)
        return cls(model, processor, action_tokenizer, config)

    def to(self, device):
        return self.model.to(device)

    # NOTE: this batch variant is unfinished and is shadowed by the
    # message/image __call__ defined below (Python keeps only the last
    # definition of a name in a class body).
    @torch.no_grad()
    def __call__(
        self,
        examples: list[dict],
        return_traj: Optional[bool] = False,
    ):
        """
        call model with examples

        Args:
            examples: list of examples, [B, example]
            return_traj: return trajectory if True
        """

        raise NotImplementedError("Not implemented yet")

        # same as the idefics2 data collator (unreachable until implemented)
        texts = []
        images = []
        for example in examples:
            image = example["images"]
            messages = example["messages"]
            text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append(image)
        inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        generate_ids = self.model.generate(**inputs, max_new_tokens=self.config.train_additional_cfg.num_actions)
        generated_text = self.processor.batch_decode(generate_ids, skip_special_tokens=True)

        if return_traj:
            return self.action_tokenizer.detokenize(generated_text)
        else:
            return generated_text

    @torch.no_grad()
    def __call__(
        self,
        message_list: list[list[dict]],
        images_list: list[list[Image.Image]],
        return_traj: Optional[bool] = False,
    ):
        """
        call model with messages and images

        Args:
            message_list: list of messages, [B, messages]
            images_list: list of images, [B, images]
            return_traj: return trajectory if True
        """

        # we don't use batch inference for the model worker
        if len(message_list) != 1:
            raise ValueError("No batch api call allowed for Idefics2Pipeline")

        message = message_list[0]
        images = images_list[0]
        prompt = self.processor.apply_chat_template(message, add_generation_prompt=True)
        # str.replace returns a new string; the original call discarded the result
        prompt = prompt.replace("<end_of_utterance>", "")
        # add a space to match the training data
        prompt = prompt + " "
        inputs = self.processor(text=prompt, images=images, return_tensors="pt", padding=True)

        device = self.model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        generate_ids = self.model.generate(
            **inputs, max_new_tokens=self.config.train_additional_cfg.num_actions, top_k=1
        )
        generated_texts = self.processor.batch_decode(generate_ids, skip_special_tokens=True)
        if return_traj:
            pred_action = re.findall(r"<ACTION_(\d+)>", generated_texts[0])
            pred_action = np.array(pred_action, dtype=np.int64)
            return self.action_tokenizer.detokenize(pred_action).tolist()
        else:
            return generated_texts
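For reference, a minimal sketch of driving Idefics2Pipeline outside Gradio, mirroring app.py (the bundled office example images are used as inputs, the system prompt is abbreviated, and return_traj=True makes the pipeline parse the generated <ACTION_n> tokens and detokenize them itself):

from PIL import Image
from canvas import Idefics2Pipeline

pipeline = Idefics2Pipeline.from_pretrained("maum-ai/CANVAS-S")

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are an indoor food-serving robot."}]},
    {"role": "user", "content": [{"type": "image"}, {"type": "image"}]},
]
images = [
    Image.open("src/office/0a1f277a93fed629365ac5863c20c64e_frontview_6.0.png"),
    Image.open("src/office/0a1f277a93fed629365ac5863c20c64e_map_6.0.png"),
]

# The pipeline accepts batches of size 1 only; each returned waypoint is an
# (x, y, yaw) triple decoded from a KMeans cluster center.
waypoints = pipeline([messages], [images], return_traj=True)
print(waypoints)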
requirements.txt
ADDED
@@ -0,0 +1,8 @@
transformers==4.46.1
datasets==3.1.0
pillow==10.4.0
numpy==2.1.3
torch==2.4.0
pydantic==2.9.2
scikit-learn==1.5.2
matplotlib==3.9.3
src/office/0a1f277a93fed629365ac5863c20c64e_frontview_6.0.png
ADDED
src/office/0a1f277a93fed629365ac5863c20c64e_map_6.0.png
ADDED
src/orchard/d578264e1e51cc5b8f0e496ab381cee4_frontview_79.0.png
ADDED
src/orchard/d578264e1e51cc5b8f0e496ab381cee4_map_79.0.png
ADDED
src/road/3cfce98ab33a3dc8d43584d5a7039cf5_frontview_8.75.png
ADDED
src/road/3cfce98ab33a3dc8d43584d5a7039cf5_map_8.75.png
ADDED
src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_frontview_4.0.png
ADDED
src/sidewalk/2d0dde2a98083b7d60b24651d37532dc_map_4.0.png
ADDED