# NOTE: keep `spaces` as the first import — Hugging Face ZeroGPU requires it to
# be imported before any library that may initialise CUDA.
import spaces
import os
import gradio as gr
from time import sleep
from signal import SIGTERM
from psutil import process_iter
from settings import GRAND3D_Settings
from utils import list_dirs
import open3d as o3d
from copy import deepcopy
import numpy as np
import re
from bs4 import BeautifulSoup
import logging

# Configure the root logger (stream handler, level INFO) and create this
# module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Project-local imports are kept after the logging setup, as in the original
# ordering (presumably so the root logger is configured before these modules
# run their import-time code — confirm before moving them).
from session import Session
from model import load_model_and_dataloader, get_model_response

# Load model and tokenizer once at the start (module import time), so every
# request reuses the same weights.
model_path = "checkpoints/lora_grounded_obj_ref_checkpoint-4896"
model_base = "checkpoints/llava-llama-2-7b-chat-lightning-preview"
load_8bit = False
load_4bit = False
load_bf16 = True
scene_to_obj_mapping = "data/predicted_scene_data_update_5.json"
# scene_to_obj_mapping = "data/scanrefer_ground_truth_scene_graph.json"
max_new_tokens = 5000
obj_context_feature_type = "text"

tokenizer, model, data_loader = load_model_and_dataloader(
    model_path=model_path,
    model_base=model_base,
    load_8bit=load_8bit,
    load_4bit=load_4bit,
    load_bf16=load_bf16,
    scene_to_obj_mapping=scene_to_obj_mapping,
)


def get_chatbot_response(user_chat_input, scene_id, temperature=0.2, top_p=0.9):
    """Run the grounding model on one user message for the given scene.

    Args:
        user_chat_input: The user's natural-language query.
        scene_id: Scene identifier forwarded to ``get_model_response``.
        temperature: Sampling temperature. Defaults to 0.2, the value this
            function previously hard-coded, so existing callers are unaffected.
        top_p: Nucleus-sampling threshold. Defaults to 0.9, the value this
            function previously hard-coded.

    Returns:
        A ``(scene_id, prompt, response)`` tuple: the scene id echoed back,
        the prompt that ``get_model_response`` reports it used, and the
        model's raw response text.
    """
    prompt, response = get_model_response(
        model=model,
        tokenizer=tokenizer,
        data_loader=data_loader,
        scene_id=scene_id,
        user_input=user_chat_input,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return scene_id, prompt, response


# Offline stub kept for debugging the UI without the model (object ids appear
# to be missing from this copy of the example data):
# def get_chatbot_response(user_chat_input):
#     # Get the response from the chatbot
#     scene_id = "scene0643_00"
#     scene_graph = """
# Object-centric context:
# : {'category': 'door', 'centroid': '[0.35, 1.99, 1.11]', 'extent': '[0.68, 0.65, 2.11]'};
# : {'category': 'ceiling', 'centroid': '[1.04, -1.39, 2.68]', 'extent': '[0.18, 0.90, 0.05]'};
# : {'category': 'ceiling', 'centroid': '[0.77, 2.09, 2.65]', 'extent': '[0.94, 0.86, 0.11]'};
# : {'category': 'trash can', 'centroid': '[-0.61, -2.16, 0.21]', 'extent': '[0.42, 0.36, 0.41]'};
# : {'category': 'chair', 'centroid': '[0.35, -1.35, 0.50]', 'extent': '[0.46, 0.47, 0.94]'};
# : {'category': 'trash can', 'centroid': '[-0.22, -2.13, 0.24]', 'extent': '[0.40, 0.28, 0.39]'};
# : {'category': 'cabinet', 'centroid': '[-1.24, 0.00, 0.58]', 'extent': '[0.61, 0.57, 0.79]'};
# : {'category': 'cup', 'centroid': '[0.62, 0.23, 0.77]', 'extent': '[0.14, 0.14, 0.08]'};
# : {'category': 'window', 'centroid': '[-0.35, -2.87, 1.13]', 'extent': '[2.05, 0.60, 1.07]'};
# : {'category': 'file cabinet', 'centroid': '[0.40, -1.97, 0.39]', 'extent': '[0.40, 0.66, 0.73]'};
# : {'category': 'monitor', 'centroid': '[0.92, -1.51, 0.97]', 'extent': '[0.25, 0.57, 0.47]'};
# : {'category': 'chair', 'centroid': '[0.34, 0.59, 0.43]', 'extent': '[0.65, 0.64, 0.94]'};
# : {'category': 'desk', 'centroid': '[0.64, 0.75, 0.57]', 'extent': '[0.76, 1.60, 0.82]'};
# : {'category': 'chair', 'centroid': '[0.55, -0.33, 0.48]', 'extent': '[0.60, 0.60, 0.87]'};
# : {'category': 'office chair', 'centroid': '[-0.28, 1.56, 0.46]', 'extent': '[0.67, 0.55, 1.02]'};
# : {'category': 'office chair', 'centroid': '[-0.86, -1.53, 0.43]', 'extent': '[0.54, 0.64, 0.97]'};
# : {'category': 'chair', 'centroid': '[-0.28, 1.56, 0.46]', 'extent': '[0.67, 0.55, 1.02]'};
# : {'category': 'monitor', 'centroid': '[0.98, 0.56, 1.05]', 'extent': '[0.21, 0.60, 0.54]'};
# : {'category': 'doorframe', 'centroid': '[-0.17, 2.42, 1.01]', 'extent': '[0.16, 0.18, 1.70]'};
# : {'category': 'chair', 'centroid': '[-0.86, -1.53, 0.43]', 'extent': '[0.54, 0.64, 0.97]'};
# : {'category': 'bookshelf', 'centroid': '[0.93, 2.00, 1.34]', 'extent': '[0.73, 0.99, 2.60]'};
# : {'category': 'office chair', 'centroid': '[0.35, -1.35, 0.50]', 'extent': '[0.46, 0.47, 0.94]'};
# : {'category': 'desk', 'centroid': '[-1.23, 1.60, 0.70]', 'extent': '[0.80, 2.01, 0.51]'};
# : {'category': 'book', 'centroid': '[0.91, 1.31, 0.89]', 'extent': '[0.34, 0.32, 0.30]'};
# : {'category': 'desk', 'centroid': '[-1.24, -1.12, 0.54]', 'extent': '[0.79, 1.88, 0.85]'};
# : {'category': 'desk', 'centroid': '[0.63, -1.51, 0.53]', 'extent': '[0.81, 1.97, 0.85]'};
# : {'category': 'calendar', 'centroid': '[-1.72, -0.44, 1.40]', 'extent': '[0.07, 0.88, 0.83]'};
# : {'category': 'office chair', 'centroid': '[0.34, 0.59, 0.43]', 'extent': '[0.65, 0.64, 0.94]'};
# : {'category': 'file cabinet', 'centroid': '[-1.02, -0.76, 0.47]', 'extent': '[0.58, 0.75, 0.81]'};
# : {'category': 'cup', 'centroid': '[-1.26, -1.65, 0.78]', 'extent': '[0.10, 0.12, 0.04]'};
# : {'category': 'keyboard', 'centroid': '[0.55, 0.84, 0.73]', 'extent': '[0.22, 0.15, 0.03]'}
#     """
#     response = """
# a
#
# brown wooden office desk
#
# [] on the left to the
#
# gray shelf
#
# [].
# These sentences refer to
#
# the brown wooden office desk
#
# [].
#     """
#     return scene_id, scene_graph, response


# Resetting to blank
def reset_textbox():
    """Return a Gradio update that clears a textbox's value."""
    return gr.update(value="")


# to set a component as visible=False
def set_visible_false():
    """Return a Gradio update that hides a component."""
    return gr.update(visible=False)


# to set a component as visible=True
def set_visible_true():
    """Return a Gradio update that shows a component."""
    return gr.update(visible=True)


def change_scene_or_system_prompt(dropdown_scene_selection: str):
    """Start a fresh session for the selected scene and reset the UI.

    Args:
        dropdown_scene_selection: Scene directory name under
            ``GRAND3D_Settings.data_path``.

    Returns:
        A 4-tuple matching the event outputs ``[session_state, model_3d,
        model_3d_grounding_result, chat_history_for_display]``: the new
        session, the path of the scene's ``.obj`` mesh, ``None`` for the
        grounding-result viewer, and the new session's display chat history.
    """
    # reset model_3d, chatbot_for_display, chat_counter, server_status_code
    new_session_state = Session.create_for_scene(dropdown_scene_selection)
    scene_mesh_path = os.path.join(
        GRAND3D_Settings.data_path,
        dropdown_scene_selection,
        f"{dropdown_scene_selection}.obj",
    )
    logger.info("Scene mesh: %s", scene_mesh_path)
    return (
        new_session_state,
        scene_mesh_path,
        None,
        new_session_state.chat_history_for_display,
    )


def cylinder_frame(p0, p1):
    """Return the 4x4 transform that places a unit-height cylinder on segment p0-p1.

    Open3D's ``TriangleMesh.create_cylinder`` builds a cylinder whose axis runs
    from (0, 0, -height/2) to (0, 0, height/2). This transform scales it along z
    to the segment length, rotates +z onto the segment direction, and moves it
    to the segment midpoint.

    Args:
        p0, p1: Segment end points (any 3-element numeric sequences).

    Returns:
        A 4x4 float ``numpy`` array. A zero-length segment yields a transform
        that flattens the cylinder to zero height at that point instead of
        producing NaNs.
    """
    p0 = np.asarray(p0, dtype=float)
    p1 = np.asarray(p1, dtype=float)
    direction = p1 - p0
    length = np.linalg.norm(direction)

    rot_matrix = np.eye(3)
    if length > 0:
        direction = direction / length
        z_axis = np.array([0.0, 0.0, 1.0])
        axis = np.cross(z_axis, direction)
        sin_angle = np.linalg.norm(axis)
        # Clip guards against |dot| drifting slightly above 1 from rounding.
        cos_angle = float(np.clip(np.dot(z_axis, direction), -1.0, 1.0))
        if sin_angle > 1e-12:
            # The axis-angle vector must be a *unit* axis scaled by the angle;
            # the raw cross product has length sin(angle).
            rot_vec = axis / sin_angle * np.arctan2(sin_angle, cos_angle)
        elif cos_angle < 0:
            # Anti-parallel to +z: any perpendicular axis works; flip about x.
            rot_vec = np.array([np.pi, 0.0, 0.0])
        else:
            rot_vec = np.zeros(3)
        rot_matrix = o3d.geometry.get_rotation_matrix_from_axis_angle(rot_vec)

    transformation = np.eye(4)
    transformation[:3, :3] = rot_matrix
    transformation[:3, 3] = (p0 + p1) / 2

    # Stretch the unit-height cylinder to the segment length before rotating.
    scaling = np.eye(4)
    scaling[2, 2] = length
    return np.matmul(transformation, scaling)


def create_cylinder_mesh(p0, p1, color, radius=0.02, resolution=20, split=1):
    """Create a uniformly coloured cylinder mesh running from p0 to p1.

    Args:
        p0, p1: Segment end points.
        color: RGB colour passed to ``paint_uniform_color``.
        radius, resolution, split: Forwarded to ``TriangleMesh.create_cylinder``.

    Returns:
        The transformed, painted ``o3d.geometry.TriangleMesh``.
    """
    cylinder = o3d.geometry.TriangleMesh.create_cylinder(
        radius=radius, height=1, resolution=resolution, split=split
    )
    cylinder.transform(cylinder_frame(p0, p1))
    cylinder.paint_uniform_color(color)
    return cylinder


def prettify_mesh_for_gradio(mesh):
    """Re-orient, enlarge and colour-clamp *mesh* in place for the Gradio viewer.

    The matrix below maps (x, y, z) -> (-y, z, -x), so the scan's z axis
    becomes the viewer's y axis (presumably to match a y-up viewer
    convention). The mesh is then scaled 10x about its centre.

    Returns:
        The same (mutated) mesh object.
    """
    T = np.array([[0, -1, 0, 0], [0, 0, 1, 0], [-1, 0, 0, 0], [0, 0, 0, 1]])
    mesh.transform(T)
    mesh.scale(10.0, center=mesh.get_center())

    bright_factor = 1  # Adjust this factor to get the desired brightness
    mesh.vertex_colors = o3d.utility.Vector3dVector(
        np.clip(np.asarray(mesh.vertex_colors) * bright_factor, 0, 1)
    )
    return mesh


def _parse_vec3(value):
    """Convert a ``"[x, y, z]"`` string or a numeric sequence to a list of floats."""
    if isinstance(value, str):
        value = value.replace("[", "").replace("]", "").split(",")
    # float() tolerates surrounding whitespace in string components.
    return [float(component) for component in value]


def create_bbox(center, extents, color=(1, 0, 0), radius=0.02):
    """Build an axis-aligned wireframe box as one cylinder mesh per edge.

    Args:
        center: Box centre, either a ``"[x, y, z]"`` string (the scene-graph
            format) or a sequence of three numbers.
        extents: Full box size along x, y and z; same accepted formats.
        color: RGB colour for every edge. A tuple default avoids the shared
            mutable-list default the previous signature used.
        radius: Radius of each edge cylinder.

    Returns:
        A list of 12 ``o3d.geometry.TriangleMesh`` cylinders.
    """
    logger.debug("bbox extents=%r (%s)", extents, type(extents))
    sx, sy, sz = _parse_vec3(extents)[:3]
    cx, cy, cz = _parse_vec3(center)[:3]
    hx, hy, hz = sx / 2, sy / 2, sz / 2

    # Corners 0-3 lie on the face z = cz + hz, corners 4-7 on z = cz - hz,
    # with matching x/y order so corner i and i + 4 share an x/y position.
    corners_3d = np.array(
        [
            [cx + hx, cy + hy, cz + hz],
            [cx + hx, cy - hy, cz + hz],
            [cx - hx, cy - hy, cz + hz],
            [cx - hx, cy + hy, cz + hz],
            [cx + hx, cy + hy, cz - hz],
            [cx + hx, cy - hy, cz - hz],
            [cx - hx, cy - hy, cz - hz],
            [cx - hx, cy + hy, cz - hz],
        ]
    )
    edges = (
        (0, 1), (1, 2), (2, 3), (3, 0),  # face at z = cz + hz
        (4, 5), (5, 6), (6, 7), (7, 4),  # face at z = cz - hz
        (0, 4), (1, 5), (2, 6), (3, 7),  # edges joining the two faces
    )
    return [
        create_cylinder_mesh(corners_3d[a], corners_3d[b], color, radius)
        for a, b in edges
    ]


def highlight_clusters_in_mesh(
    centroids_extents_detailed,
    centroids_extends_refer,
    mesh,
    output_dir,
    output_file_name="highlighted_mesh.glb",
):
    """Overlay wireframe boxes on a copy of *mesh* and write the result to disk.

    Args:
        centroids_extents_detailed: ``(centroid, extent)`` pairs drawn in blue.
        centroids_extends_refer: ``(centroid, extent)`` pairs drawn in green.
        mesh: Scene mesh; it is deep-copied, so the caller's object is not
            modified.
        output_dir: Directory under which a ``mesh_vis`` sub-directory is
            created if needed.
        output_file_name: File name of the written mesh.

    Returns:
        Path of the written mesh file.

    Raises:
        OSError: if Open3D reports that writing the mesh failed.
    """
    output_path = os.path.join(output_dir, "mesh_vis")
    os.makedirs(output_path, exist_ok=True)

    combined_mesh = o3d.geometry.TriangleMesh()
    combined_mesh += deepcopy(mesh)

    # Blue boxes for the "detailed" objects.
    for center, extent in centroids_extents_detailed:
        logger.debug("center: %s extent: %s", center, extent)
        for edge in create_bbox(center, extent, color=[0, 0, 1]):
            combined_mesh += edge
    # Green boxes for the referred objects.
    for center, extent in centroids_extends_refer:
        for edge in create_bbox(center, extent, color=[0, 1, 0]):
            combined_mesh += edge

    combined_mesh = prettify_mesh_for_gradio(combined_mesh)

    output_file_path = os.path.join(output_path, output_file_name)
    written = o3d.io.write_triangle_mesh(
        output_file_path, combined_mesh, write_vertex_colors=True
    )
    if not written:
        raise OSError(f"Failed to write highlighted mesh to {output_file_path}")
    return output_file_path


def extract_objects(text):
    """Return every substring of *text* that matches the object-id pattern."""
    # NOTE(review): the pattern is empty in this copy of the file, so it
    # matches the empty string at every position. It looks like `<...>` tag
    # text was stripped — restore the intended object-id pattern.
    return re.findall(r"", text)


# Parse the scene graph into a dictionary
def parse_scene_graph(scene_graph):
    """Parse the ``: {...}`` object entries of a scene-graph prompt into a dict.

    Each matched dict literal is parsed with ``ast.literal_eval`` rather than
    ``eval``. The prompt presumably also embeds the user's message, so
    ``eval`` would execute any ``: {...}`` expression a user typed; only
    Python literals are accepted here. Entries that are not valid literals
    are skipped with a warning.

    Returns:
        Mapping from object id to its attribute dict. Consumers read the
        ``'centroid'`` and ``'extent'`` keys.
    """
    import ast

    scene_dict = {}
    for match in re.findall(r": (\{.*?\})", scene_graph):
        # re.findall yields tuples for multi-group patterns and plain strings
        # for single-group ones; the dict literal is the last group either way.
        literal = match[-1] if isinstance(match, tuple) else match
        # NOTE(review): this id f-string is empty in this copy of the file, so
        # every object lands on the same key; restore the intended id format.
        obj_id = f""
        try:
            obj_data = ast.literal_eval(literal)
        except (ValueError, SyntaxError, TypeError):
            logger.warning("Skipping unparsable scene-graph entry: %r", literal)
            continue
        scene_dict[obj_id] = obj_data
    return scene_dict


def get_centroids_extents(obj_list, scene_dict):
    """Return ``(centroid, extent)`` for each id in *obj_list* present in *scene_dict*."""
    return [
        (scene_dict[obj]["centroid"], scene_dict[obj]["extent"])
        for obj in obj_list
        if obj in scene_dict
    ]


@spaces.GPU
def language_model_forward(
    session_state, user_chat_input, top_p, temperature, dropdown_scene
):
    """Ground one user query in the selected scene, streaming UI updates.

    This generator yields twice. The first yield has the user's message
    appended and no grounding result. The second has the model's response
    and the path of the mesh with highlighted boxes.

    Args:
        session_state: Incoming session state. NOTE(review): ignored; a new
            session is created for *dropdown_scene* on every call, so earlier
            turns are not carried over.
        user_chat_input: The user's query text.
        top_p, temperature: Values from the hidden "Parameters" sliders.
            NOTE(review): currently unused; ``get_chatbot_response`` applies
            its own sampling settings.
        dropdown_scene: Selected scene id.

    Yields:
        ``(session_state, grounding_mesh_path_or_None, chat_history_for_display)``.

    Raises:
        RuntimeError: if the model reports a different scene than requested.
    """
    session_state = Session.create_for_scene(dropdown_scene)
    # History entries are (user_message, assistant_response) tuples.
    session_state.chat_history_for_display.append((user_chat_input, None))
    yield session_state, None, session_state.chat_history_for_display

    original_model_path = os.path.join(
        GRAND3D_Settings.data_path,
        session_state.scene,
        f"{session_state.scene}.obj",
    )
    logger.info("original_model_path: %s", original_model_path)
    # Load the scene mesh (.obj).
    mesh = o3d.io.read_triangle_mesh(original_model_path)

    scene_id, scene_graph, response = get_chatbot_response(
        user_chat_input, session_state.scene
    )
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if scene_id != session_state.scene:
        raise RuntimeError(
            f"Model answered for scene {scene_id!r}, expected {session_state.scene!r}"
        )

    scene_dict = parse_scene_graph(scene_graph)
    logger.info("Model Input: %s", scene_dict)
    logger.info("Model Response: %s", response)

    # Parse the response to get detailed and refer expression groundings.
    soup = BeautifulSoup(response, "html.parser")
    detailed_section = soup.find("detailed_grounding")
    refer_section = soup.find("refer_expression_grounding")
    # A missing section contributes no objects (previously it became "None").
    detailed_objects = extract_objects(
        str(detailed_section) if detailed_section is not None else ""
    )
    refer_objects = extract_objects(
        str(refer_section) if refer_section is not None else ""
    )
    logger.info("detailed_objects: %s", detailed_objects)
    logger.info("refer_objects: %s", refer_objects)

    # Objects mentioned in the detailed grounding but not referred to.
    remaining_objects = list(set(detailed_objects) - set(refer_objects))
    logger.info("remaining_objects: %s", remaining_objects)

    centroids_extents_detailed = get_centroids_extents(remaining_objects, scene_dict)
    centroids_extents_refer = get_centroids_extents(refer_objects, scene_dict)
    logger.info("centroids_extents_detailed: %s", centroids_extents_detailed)
    logger.info("centroids_extents_refer: %s", centroids_extents_refer)

    highlighted_model_path = highlight_clusters_in_mesh(
        centroids_extents_detailed,
        centroids_extents_refer,
        mesh,
        session_state.get_session_output_dir(),
        output_file_name="highlighted_model.glb",
    )

    # Fill in the assistant half of the turn appended above.
    user_message = session_state.chat_history_for_display[-1][0]
    session_state.chat_history_for_display[-1] = (user_message, response)
    session_state.save()
    yield session_state, highlighted_model_path, session_state.chat_history_for_display


title = """

🤖 3D-GRAND: Towards Better Grounding and Less Hallucination for 3D-LLMs 🚀

[Project Page] [3D-GRAND Data] [3D-POPE Data]

"""

# Modifying existing Gradio Theme
# theme = gr.themes.Soft(
#     primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.pink
# )

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    session_state = gr.State(Session.create)
    gr.HTML(title)

    with gr.Column():
        with gr.Row():
            with gr.Column(scale=5):
                dropdown_scene = gr.Dropdown(
                    choices=list_dirs(GRAND3D_Settings.data_path),
                    value=GRAND3D_Settings.default_scene,
                    interactive=True,
                    label="Select a scene",
                )
                model_3d = gr.Model3D(
                    value=os.path.join(
                        GRAND3D_Settings.data_path,
                        GRAND3D_Settings.default_scene,
                        f"{GRAND3D_Settings.default_scene}.obj",
                    ),
                    clear_color=[0.0, 0.0, 0.0, 0.0],
                    label="3D Model",
                    camera_position=(-50, 65, 10),
                    zoom_speed=10.0,
                )
                gr.HTML(
                    """
👆 SCROLL or DRAG on the 3D Model to zoom in/out and rotate. Press CTRL and DRAG to pan.
"""
                )
                gr.HTML(
                    """
👇 When grounding finishes, the grounding result will be displayed below.
"""
                )
                model_3d_grounding_result = gr.Model3D(
                    clear_color=[0.0, 0.0, 0.0, 0.0],
                    label="Grounding Result",
                    zoom_speed=15.0,
                )
                gr.HTML(
                    """
= Landmark  
= Chosen Target
"""
                )
            with gr.Column(scale=5):
                # NOTE(review): sanitize_html=False renders user input and model
                # output in the chat as raw HTML.
                chat_history_for_display = gr.Chatbot(
                    value=[(None, GRAND3D_Settings.INITIAL_MSG_FOR_DISPLAY)],
                    label="Chat Assistant",
                    height=510,
                    render_markdown=False,
                    sanitize_html=False,
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        user_chat_input = gr.Textbox(
                            placeholder="I want to find the chair near the table",
                            show_label=False,
                        )
                    with gr.Column(scale=1, min_width=0):
                        send_button = gr.Button("Send", variant="primary")
                    with gr.Column(scale=1, min_width=0):
                        clear_button = gr.Button("Clear")
                with gr.Row():
                    with gr.Accordion(label="Examples for user message:", open=True):
                        gr.Examples(
                            examples=[
                                ["The TV on the drawer, opposing the bed."],
                                ["the desk next to the window"],
                            ],
                            inputs=user_chat_input,
                        )
                with gr.Accordion("Parameters", open=False, visible=False):
                    top_p = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        value=1.0,
                        step=0.05,
                        interactive=True,
                        label="Top-p (nucleus sampling)",
                    )
                    temperature = gr.Slider(
                        minimum=0,
                        maximum=5.0,
                        value=1.0,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
    # gr.Markdown("### Terms of Service")
    # gr.HTML(
    #     """By using this service, users are required to agree to the following terms:
    # The service is a research preview intended for non-commercial use only.
    # The service may collect user dialogue data for future research."""
    # )

    # Event handling: changing the scene or pressing Clear resets the session.
    scene_reset_outputs = [
        session_state,
        model_3d,
        model_3d_grounding_result,
        chat_history_for_display,
    ]
    dropdown_scene.change(
        fn=change_scene_or_system_prompt,
        inputs=[dropdown_scene],
        outputs=scene_reset_outputs,
    )
    clear_button.click(
        fn=change_scene_or_system_prompt,
        inputs=[dropdown_scene],
        outputs=scene_reset_outputs,
    )

    forward_inputs = [session_state, user_chat_input, top_p, temperature, dropdown_scene]
    forward_outputs = [session_state, model_3d_grounding_result, chat_history_for_display]
    user_chat_input.submit(
        fn=language_model_forward, inputs=forward_inputs, outputs=forward_outputs
    )
    send_button.click(
        fn=language_model_forward, inputs=forward_inputs, outputs=forward_outputs
    )

    send_button.click(reset_textbox, [], [user_chat_input])
    user_chat_input.submit(reset_textbox, [], [user_chat_input])


# Launch with retries and exponential backoff. On an OSError (e.g. a busy
# port), terminate processes listening on `port`, then try again.
# NOTE(review): `port` is only used to pick processes to kill; launch() does
# not receive it (server_port is commented out), so Gradio chooses its own port.
from psutil import AccessDenied, NoSuchProcess

sleep_time = 2
port = 7011
max_attempts = 9
for x in range(1, max_attempts + 1):
    try:
        gr.close_all()
        demo.queue(
            max_size=20,
        ).launch(
            # debug=True,
            # server_name="0.0.0.0",
            # server_port=port,
            # share=True
        )
    except OSError:
        for proc in process_iter():
            try:
                connections = proc.connections(kind="inet")
                # laddr may be an empty tuple for unbound sockets.
                if any(getattr(conn.laddr, "port", None) == port for conn in connections):
                    proc.send_signal(SIGTERM)  # or SIGKILL
            except (AccessDenied, NoSuchProcess):
                continue  # cannot inspect/signal this process; skip it
        print(f"Retrying {x} time...")
    else:
        break  # launch() returned normally; do not relaunch the app
    if x < max_attempts:
        sleep(sleep_time)  # wait before trying again
        sleep_time *= 2  # exponential backoff
else:
    logger.error("Failed to launch the demo after %d attempts", max_attempts)