""" Sample TensorFlow XML-to-TFRecord converter | |
usage: generate_tfrecord.py [-h] [-x XML_DIR] [-l LABELS_PATH] [-o OUTPUT_PATH] [-i IMAGE_DIR] [-c CSV_PATH] | |
optional arguments: | |
-h, --help show this help message and exit | |
-x XML_DIR, --xml_dir XML_DIR | |
Path to the folder where the input .xml files are stored. | |
-l LABELS_PATH, --labels_path LABELS_PATH | |
Path to the labels (.pbtxt) file. | |
-o OUTPUT_PATH, --output_path OUTPUT_PATH | |
Path of output TFRecord (.record) file. | |
-i IMAGE_DIR, --image_dir IMAGE_DIR | |
Path to the folder where the input image files are stored. Defaults to the same directory as XML_DIR. | |
-c CSV_PATH, --csv_path CSV_PATH | |
Path of output .csv file. If none provided, then no file will be written. | |
""" | |

import os
import glob
import pandas as pd
import io
import xml.etree.ElementTree as ET
import argparse

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'    # Suppress TensorFlow logging (1)
import tensorflow.compat.v1 as tf
from PIL import Image
from object_detection.utils import dataset_util, label_map_util
from collections import namedtuple

# Initiate argument parser
parser = argparse.ArgumentParser(
    description="Sample TensorFlow XML-to-TFRecord converter")
parser.add_argument("-x",
                    "--xml_dir",
                    help="Path to the folder where the input .xml files are stored.",
                    type=str)
parser.add_argument("-l",
                    "--labels_path",
                    help="Path to the labels (.pbtxt) file.", type=str)
parser.add_argument("-o",
                    "--output_path",
                    help="Path of output TFRecord (.record) file.", type=str)
parser.add_argument("-i",
                    "--image_dir",
                    help="Path to the folder where the input image files are stored. "
                         "Defaults to the same directory as XML_DIR.",
                    type=str, default=None)
parser.add_argument("-c",
                    "--csv_path",
                    help="Path of output .csv file. If none provided, then no file will be "
                         "written.",
                    type=str, default=None)

args = parser.parse_args()

if args.image_dir is None:
    args.image_dir = args.xml_dir

label_map = label_map_util.load_labelmap(args.labels_path)
label_map_dict = label_map_util.get_label_map_dict(label_map)
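# The file passed via --labels_path is a standard label_map .pbtxt; a minimal example
# (the class names here are purely illustrative) looks like:
#   item {
#       id: 1
#       name: 'cat'
#   }
#   item {
#       id: 2
#       name: 'dog'
#   }
# label_map_dict is then a plain dict mapping each class name to its integer id,
# e.g. {'cat': 1, 'dog': 2}, which class_text_to_int() below relies on.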


def xml_to_csv(path):
    """Iterates through all .xml files (generated by labelImg) in a given directory and combines
    them in a single Pandas dataframe.

    Parameters:
    ----------
    path : str
        The path containing the .xml files
    Returns
    -------
    Pandas DataFrame
        The produced dataframe
    """
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            value = (root.find('filename').text,
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
                     int(member[4][1].text),
                     int(member[4][2].text),
                     int(member[4][3].text)
                     )
            xml_list.append(value)
    column_name = ['filename', 'width', 'height',
                   'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df


def class_text_to_int(row_label):
    return label_map_dict[row_label]


def split(df, group):
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]


def create_tf_example(group, path):
    with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size
    filename = group.filename.encode('utf8')
    image_format = b'jpg'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for index, row in group.object.iterrows():
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example


def main(_):
    writer = tf.python_io.TFRecordWriter(args.output_path)
    path = os.path.join(args.image_dir)
    examples = xml_to_csv(args.xml_dir)
    grouped = split(examples, 'filename')
    for group in grouped:
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())
    writer.close()
    print('Successfully created the TFRecord file: {}'.format(args.output_path))
    if args.csv_path is not None:
        examples.to_csv(args.csv_path, index=None)
        print('Successfully created the CSV file: {}'.format(args.csv_path))


if __name__ == '__main__':
    tf.app.run()
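
After the script has run, it can be useful to confirm that the generated .record file parses back cleanly. The following is a minimal, separate sketch (not part of generate_tfrecord.py) that reads one record with the TF2 eager API and prints a few of the stored features; the path 'annotations/train.record' is an assumed example output and should be replaced with whatever was passed to -o.

import tensorflow as tf  # TF2 eager API, used only for inspection

# A subset of the feature keys written by create_tf_example() above.
feature_description = {
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/object/class/text': tf.io.VarLenFeature(tf.string),
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
}

# 'annotations/train.record' is an illustrative path, not one the script creates by itself.
dataset = tf.data.TFRecordDataset('annotations/train.record')
for raw_record in dataset.take(1):
    parsed = tf.io.parse_single_example(raw_record, feature_description)
    print(parsed['image/filename'].numpy())
    print(tf.sparse.to_dense(parsed['image/object/class/text']).numpy())
    print(tf.sparse.to_dense(parsed['image/object/bbox/xmin']).numpy())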