#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Tile images into a grid and save it, with optional per-image, row and column labels."""

import argparse
import sys
from typing import List, Optional, Sequence

import cv2
import numpy as np

# Text style shared by every label drawn on the canvas.
_FONT = cv2.FONT_HERSHEY_SIMPLEX
_FONT_SCALE = 1
_FONT_THICKNESS = 2
_TEXT_COLOR = (0, 0, 0)  # black


def _error(message: str) -> None:
    """Report *message* on stderr with an ``Error:`` prefix."""
    print(f"Error: {message}", file=sys.stderr)


def _put_text(canvas: np.ndarray, text: str, origin) -> None:
    """Draw *text* on *canvas* with the shared style; *origin* is the (x, y) of the text baseline start."""
    cv2.putText(canvas, text, origin, _FONT, _FONT_SCALE, _TEXT_COLOR,
                _FONT_THICKNESS, cv2.LINE_AA)


def _load_images(paths: Sequence[str]) -> Optional[List[np.ndarray]]:
    """Read every path with ``cv2.imread``; return the arrays, or None (after reporting) if any read fails."""
    loaded = []
    for path in paths:
        image = cv2.imread(path)
        # imread signals a missing/unreadable file by returning None rather than raising.
        if image is None:
            _error(f"Could not read image '{path}'.")
            return None
        loaded.append(image)
    return loaded


def save_image_plot(images: Sequence[str],
                    labels: Optional[Sequence[str]] = None,
                    output: str = 'output.jpg',
                    rows: int = 1,
                    row_labels: Optional[Sequence[str]] = None,
                    column_labels: Optional[Sequence[str]] = None) -> bool:
    """Save *images* as a grid on a white canvas, with optional text labels.

    Images fill the grid row by row; the number of columns is
    ``ceil(len(images) / rows)``, so the last row may be partial. Every cell
    has the size of the first image; other images are resized to match.

    Args:
        images: Paths of the image files, in grid order.
        labels: Optional label per image, drawn centred in the band above it.
        output: Path of the file to write (format chosen by its extension).
        rows: Number of grid rows (must be >= 1).
        row_labels: Optional label per row, drawn at the left edge of the row.
        column_labels: Optional label per column, drawn in the top band.

    Returns:
        True if the plot was written. False if validation, reading or writing
        failed; the reason is printed to stderr.
    """
    if not images:
        _error("At least one image is required.")
        return False
    if rows < 1:
        _error("Number of rows must be at least 1.")
        return False

    cols = (len(images) + rows - 1) // rows  # ceiling division

    labels = list(labels) if labels else [''] * len(images)
    row_labels = list(row_labels) if row_labels else [''] * rows
    column_labels = list(column_labels) if column_labels else [''] * cols

    if len(labels) != len(images):
        _error("Number of labels should match the number of images.")
        return False
    if len(row_labels) != rows:
        _error("Number of row labels should match the number of rows.")
        return False
    if len(column_labels) != cols:
        _error("Number of column labels should match the number of columns.")
        return False

    pictures = _load_images(images)
    if pictures is None:
        return False

    # The first image fixes the size of every grid cell.
    image_height, image_width = pictures[0].shape[:2]

    # A label band of top_padding pixels always sits above the first row (for
    # column labels). Bands above the other rows are added only when
    # per-image or row labels exist; otherwise rows are stacked without gaps.
    top_padding = 50
    has_other_labels = any(labels) or any(row_labels)
    row_gap = top_padding if has_other_labels else 0
    row_pitch = image_height + row_gap  # vertical distance between row origins

    # Extra room on the left for row labels, only when one is non-blank.
    left_padding = 40 if any(label.strip() for label in row_labels) else 0

    canvas_height = top_padding + rows * image_height + (rows - 1) * row_gap
    canvas_width = left_padding + cols * image_width
    canvas = np.full((canvas_height, canvas_width, 3), 255, dtype=np.uint8)

    # Paste every image first, so labels drawn afterwards are never painted over.
    for index, picture in enumerate(pictures):
        row, col = divmod(index, cols)
        if picture.shape[:2] != (image_height, image_width):
            # cv2.resize takes dsize as (width, height).
            picture = cv2.resize(picture, (image_width, image_height))
        y_start = top_padding + row * row_pitch
        x_start = left_padding + col * image_width
        canvas[y_start:y_start + image_height, x_start:x_start + image_width] = picture

    # Column labels: left-aligned over each column, in the top band.
    for col, label in enumerate(column_labels):
        _put_text(canvas, label, (left_padding + col * image_width, 30))

    # Row labels: same position as before (baseline 60px below the row origin).
    for row, label in enumerate(row_labels):
        if label.strip():
            _put_text(canvas, label, (0, row * row_pitch + 60))

    # Per-image labels: centred in the band just above each image. That band
    # exists whenever any per-image label is non-empty (has_other_labels).
    # NOTE: in the first row these can overlap column labels if both are given.
    for index, label in enumerate(labels):
        if not label:
            continue
        row, col = divmod(index, cols)
        (text_width, _), _ = cv2.getTextSize(label, _FONT, _FONT_SCALE, _FONT_THICKNESS)
        x_start = left_padding + col * image_width
        x_text = max(x_start, x_start + (image_width - text_width) // 2)
        y_text = top_padding + row * row_pitch - 10
        _put_text(canvas, label, (x_text, y_text))

    # imwrite returns False on failure; some OpenCV versions raise cv2.error
    # instead (e.g. for an unknown file extension).
    try:
        written = cv2.imwrite(output, canvas)
    except cv2.error as exc:
        _error(f"Could not write plot to '{output}': {exc}")
        return False
    if not written:
        _error(f"Could not write plot to '{output}'.")
        return False

    print(f"Generated plot saved as {output}")
    return True


def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command line (defaults to sys.argv[1:])."""
    parser = argparse.ArgumentParser(description='Display images with optional labels.')
    parser.add_argument('images', nargs='+', help='List of image file paths')
    parser.add_argument('--labels', nargs='+', default=[], help='List of optional labels for each image')
    parser.add_argument('--output', default='output.jpg', help='Output file name for the generated plot')
    parser.add_argument('--rows', type=int, default=1, help='Number of rows to display the images')
    parser.add_argument('--row-labels', nargs='+', default=[], help='List of optional labels for each row')
    parser.add_argument('--column-labels', nargs='+', default=[], help='List of optional labels for each column')
    return parser.parse_args(argv)


def main(argv: Optional[Sequence[str]] = None) -> int:
    """CLI entry point: build the plot from the arguments; return the process exit status."""
    args = _parse_args(argv)
    ok = save_image_plot(args.images, args.labels, args.output, args.rows,
                         args.row_labels, args.column_labels)
    return 0 if ok else 1


if __name__ == '__main__':
    sys.exit(main())