"""Gradio front end for 3DVascNet retinal blood vessel segmentation.

Upload a 3D ``.tif`` stack, run the segmentation model on it, download the
resulting mask and browse image/mask slices side by side.
"""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tifffile

from prediction import predict_mask

# Text shown at the top of the page. The heading sits on its own line so that
# only the title is rendered as a Markdown H1.
_DESCRIPTION_MD = (
    "# 3DVascNet: Retinal Blood Vessel Segmentation\n"
    "\n"
    "Upload a 3D image in .tif format. Click the **Segment** button to process "
    "the image and generate a 3D binary mask. Use the slider to navigate "
    "through the 2D slices.\n"
    "\n"
    "This is the official implementation of 3DVascNet, described in this "
    "paper: https://www.ahajournals.org/doi/10.1161/ATVBAHA.124.320672.\n"
    "\n"
    "The raw code is available at https://github.com/HemaxiN/3DVascNet.\n"
)

# Accepted upload extensions (compared case-insensitively).
_TIFF_SUFFIXES = (".tif", ".tiff")


def process_3d_image(image, resx, resy, resz):
    """Run the segmentation model on a 3D image and return its mask.

    Thin wrapper around ``prediction.predict_mask``; the image and the three
    resolution values are forwarded unchanged.
    """
    return predict_mask(image, resx, resy, resz)


def auximread(filepath):
    """Read a TIFF file and move its smallest axis to the last position.

    The smallest of the three axes is presumably the Z (slice) axis, so the
    result is laid out as (X, Y, Z) -- TODO confirm against the data.

    Arrays that are not 3D are returned untouched so that the caller can
    report a clear error instead of this function failing with an
    ``IndexError`` while inspecting the shape.
    """
    image = tifffile.imread(filepath)
    if image.ndim != 3:
        return image

    smallest_axis = int(np.argmin(image.shape))
    if smallest_axis == 0:
        image = image.transpose(1, 2, 0)
    elif smallest_axis == 1:
        image = image.transpose(0, 2, 1)
    # smallest_axis == 2: the smallest axis is already last.
    return image


def process_file(file, resx, resy, resz):
    """Load the uploaded file, segment it and save the mask to disk.

    Args:
        file: The upload handed over by ``gr.File``: either a path string or
            an object exposing the path as ``.name``.
        resx, resy, resz: Resolution values forwarded to the model.

    Returns:
        Tuple ``(image, binary_mask, output_path)``.

    Raises:
        ValueError: If nothing was uploaded, the file is not a TIFF, or the
            loaded image is not 3D.
    """
    if file is None:
        raise ValueError("No file uploaded. Please upload a .tif file.")

    # Depending on the Gradio version the upload is a path string or a
    # tempfile-like wrapper with a ``name`` attribute.
    filepath = str(getattr(file, "name", file))
    if not filepath.lower().endswith(_TIFF_SUFFIXES):
        raise ValueError("Unsupported file format. Please upload a .tif file.")

    image = auximread(filepath)
    if image.ndim != 3:
        raise ValueError("Input image is not 3D.")

    binary_mask = process_3d_image(image, resx, resy, resz)

    # NOTE(review): fixed file name in the working directory; concurrent
    # sessions would overwrite each other's result.
    output_path = "output_mask.tif"
    tifffile.imwrite(output_path, binary_mask)

    return image, binary_mask, output_path


def visualize_slice(image, mask, slice_index):
    """Plot one slice of the image next to the same slice of the mask.

    Slices are taken along the last axis. Returns the Matplotlib figure,
    already detached from pyplot so figures do not accumulate between calls.
    """
    # The slider may deliver a float; NumPy indexing needs an integer.
    slice_index = int(slice_index)

    # Slice before creating the figure so an out-of-range index cannot leave
    # an orphaned open figure behind.
    panels = (
        (image[:, :, slice_index], "Image Slice"),
        (mask[:, :, slice_index], "Mask Slice"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, (plane, title) in zip(axes, panels):
        ax.imshow(plane, cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    fig.tight_layout()
    plt.close(fig)
    return fig


# Most recently processed image and mask, read by the slider callback.
# NOTE(review): module-level state is shared by every session of the app.
processed_image = None
processed_mask = None


def segment_button_click(file, resx, resy, resz):
    """Handle the Segment button: run the pipeline and reveal the slider.

    Returns the status message, the path of the saved mask and an update that
    makes the slider visible, sizes it to the number of slices (last axis of
    the image) and resets it to the first slice so a value left over from a
    previous, deeper volume cannot be out of range.
    """
    global processed_image, processed_mask
    processed_image, processed_mask, output_path = process_file(
        file, resx, resy, resz
    )
    num_slices = processed_image.shape[2]
    slider_update = gr.update(visible=True, maximum=num_slices - 1, value=0)
    return (
        "Segmentation completed! Use the slider to explore slices.",
        output_path,
        slider_update,
    )


def update_visualization(slice_index):
    """Handle slider changes: plot the requested slice of the stored volume.

    Raises:
        ValueError: If no image has been segmented yet.
    """
    if processed_image is None or processed_mask is None:
        raise ValueError(
            "Please process an image first by clicking the Segment button."
        )
    return visualize_slice(processed_image, processed_mask, slice_index)


# Gradio interface.
# NOTE(review): the grouping of components into rows was reconstructed from
# the statement order -- confirm it matches the intended layout.
with gr.Blocks() as iface:
    gr.Markdown(_DESCRIPTION_MD)

    # Input fields for resolution in micrometers.
    with gr.Row():
        resx_input = gr.Number(value=0.333, label="Resolution in X (µm)", precision=3)
        resy_input = gr.Number(value=0.333, label="Resolution in Y (µm)", precision=3)
        resz_input = gr.Number(value=0.5, label="Resolution in Z (µm)", precision=3)

    with gr.Row():
        file_input = gr.File(label="Upload 3D Image (.tif)")
        segment_button = gr.Button("Segment")

    status_output = gr.Textbox(label="Status", interactive=False)
    download_output = gr.File(label="Download Binary Mask (.tif)")

    with gr.Row():
        # Hidden until a segmentation has finished; its maximum is then set
        # to the last slice index of the processed volume.
        slice_slider = gr.Slider(
            minimum=0,
            maximum=100,
            step=1,
            label="Slice Index",
            interactive=True,
            visible=False,
        )
        visualization_output = gr.Plot(label="2D Slice Visualization")

    # Button click triggers segmentation.
    segment_button.click(
        segment_button_click,
        inputs=[file_input, resx_input, resy_input, resz_input],
        outputs=[status_output, download_output, slice_slider],
    )

    # Slider changes trigger visualization updates.
    slice_slider.change(
        update_visualization,
        inputs=slice_slider,
        outputs=visualization_output,
    )


if __name__ == "__main__":
    iface.launch()