File size: 4,950 Bytes
c2cc76d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Import necessary libraries
import cv2
import argparse
import numpy as np

# Build the command line interface. Each entry pairs an argument name with the
# keyword options handed to add_argument; entries are registered in the order
# listed so the generated --help text keeps the same sequence.
parser = argparse.ArgumentParser(description='Display images with optional labels.')
for arg_name, arg_options in (
    ('images', {'nargs': '+', 'help': 'List of image file paths'}),
    ('--labels', {'nargs': '+', 'default': [], 'help': 'List of optional labels for each image'}),
    ('--output', {'default': 'output.jpg', 'help': 'Output file name for the generated plot'}),
    ('--rows', {'type': int, 'default': 1, 'help': 'Number of rows to display the images'}),
    ('--row-labels', {'nargs': '+', 'default': [], 'help': 'List of optional labels for each row'}),
    ('--column-labels', {'nargs': '+', 'default': [], 'help': 'List of optional labels for each column'}),
):
    parser.add_argument(arg_name, **arg_options)

# Read the options from the process command line (argparse prints usage and
# exits on invalid input).
args = parser.parse_args()

# Function to save images with optional labels as a plot with a white section on top
def save_image_plot(images, labels=None, output='output.jpg', rows=1, row_labels=None, column_labels=None):
    """Tile images into a grid on a white canvas and write it to ``output``.

    Images are placed row by row, left to right, using
    ``ceil(len(images) / rows)`` columns. Every image must have the same
    height and width as the first one. Column labels are drawn once along the
    top of the canvas; row labels are drawn in a margin on the left that is
    widened as needed so the text is not covered by the first column.

    Args:
        images: Paths of the image files to tile; must not be empty.
        labels: Optional per-image labels, one per image. They are validated
            and, when any is non-empty, a blank band is reserved above every
            row; no text is drawn for them by this function.
        output: Path of the image file to write.
        rows: Number of grid rows; must be at least 1.
        row_labels: Optional labels, one per row.
        column_labels: Optional labels, one per column.

    Returns:
        None. Invalid arguments, unreadable or differently sized images and a
        failed write are reported on stdout; in the first three cases nothing
        is written.
    """
    # Guard the inputs that would otherwise crash below (images[0], // rows).
    if not images:
        print("Error: At least one image is required.")
        return

    if rows < 1:
        print("Error: Number of rows should be at least 1.")
        return

    cols = (len(images) + rows - 1) // rows

    if not labels:
        labels = [''] * len(images)

    if not row_labels:
        row_labels = [''] * rows

    if not column_labels:
        column_labels = [''] * cols

    # Validate the number of labels matches the number of images
    if len(labels) != len(images):
        print("Error: Number of labels should match the number of images.")
        return

    # Validate the number of row labels matches the number of rows
    if len(row_labels) != rows:
        print("Error: Number of row labels should match the number of rows.")
        return

    # Validate the number of column labels matches the number of columns
    if len(column_labels) != cols:
        print("Error: Number of column labels should match the number of columns.")
        return

    # Read the first image to determine the cell dimensions.
    # cv2.imread signals failure by returning None instead of raising.
    first_image = cv2.imread(images[0])
    if first_image is None:
        print(f"Error: Could not read image '{images[0]}'.")
        return
    image_height, image_width = first_image.shape[:2]

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    thickness = 2
    text_color = (0, 0, 0)
    top_padding = 50

    # With no per-image or row labels the rows are stacked without gaps and
    # only one band is kept at the top (for the column labels); otherwise
    # every row gets its own band above it.
    has_other_labels = any(labels) or any(row_labels)
    if has_other_labels:
        row_stride = image_height + top_padding
        canvas_height = row_stride * rows
    else:
        row_stride = image_height
        canvas_height = image_height * rows + top_padding

    # Reserve a left margin for row labels: 40 px when every label fits in
    # that, otherwise wide enough for the longest label plus a small gap, so
    # the image pasted in the first column does not cover the text.
    left_padding = 0
    visible_row_labels = [label for label in row_labels if label.strip()]
    if visible_row_labels:
        widest = max(
            cv2.getTextSize(label, font, font_scale, thickness)[0][0]
            for label in visible_row_labels
        )
        left_padding = 40 if widest <= 40 else widest + 10

    canvas_width = image_width * cols + left_padding
    canvas = np.full((canvas_height, canvas_width, 3), 255, dtype=np.uint8)

    # Add column labels at the top, left-aligned with their column
    for col, label in enumerate(column_labels):
        x_pos = col * image_width + left_padding
        cv2.putText(canvas, label, (x_pos, 30), font, font_scale, text_color, thickness, cv2.LINE_AA)

    # Paste every image into its grid cell
    for i, img_path in enumerate(images):
        image = first_image if i == 0 else cv2.imread(img_path)
        if image is None:
            print(f"Error: Could not read image '{img_path}'.")
            return
        if image.shape[:2] != (image_height, image_width):
            print(f"Error: Image '{img_path}' does not match the size of the first image.")
            return

        row, col = divmod(i, cols)
        y_start = row * row_stride + top_padding
        x_start = col * image_width + left_padding

        # Draw the row label once, next to the first image of the row.
        # Rows without any image get no label.
        if col == 0 and row_labels[row].strip():
            cv2.putText(canvas, row_labels[row], (0, y_start + 10), font, font_scale, text_color, thickness, cv2.LINE_AA)

        canvas[y_start:y_start + image_height, x_start:x_start + image_width, :] = image

    # Save the generated plot; cv2.imwrite returns False when the write fails
    if not cv2.imwrite(output, canvas):
        print(f"Error: Could not write plot to {output}.")
        return
    print(f"Generated plot saved as {output}")

# Build and save the plot from the parsed command line options
save_image_plot(
    images=args.images,
    labels=args.labels,
    output=args.output,
    rows=args.rows,
    row_labels=args.row_labels,
    column_labels=args.column_labels,
)