Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
A patch for training deeplabv3 on the ADE20K dataset
diff --git a/research/deeplab/datasets/build_ade20k_data.py b/research/deeplab/datasets/build_ade20k_data.py
new file mode 100644
index 0000000..7dcf79c
--- /dev/null
+++ b/research/deeplab/datasets/build_ade20k_data.py
@@ -0,0 +1,116 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import glob
+import math
+import os
+import random
+import string
+import sys
+from PIL import Image
+import build_data
+import tensorflow as tf
+
+FLAGS = tf.app.flags.FLAGS
+flags = tf.app.flags
+
+tf.app.flags.DEFINE_string(
+    'train_image_folder',
+    './ADEChallengeData2016/images/training',
+    'Folder containing trainng images')
+tf.app.flags.DEFINE_string(
+    'train_image_label_folder',
+    './ADEChallengeData2016/annotations/training',
+    'Folder containing annotations for trainng images')
+# Validation data uses the same ADEChallengeData2016 layout as training.
+tf.app.flags.DEFINE_string(
+    'val_image_folder',
+    './ADEChallengeData2016/images/validation',
+    'Folder containing validation images')
+
+tf.app.flags.DEFINE_string(
+    'val_image_label_folder',
+    './ADEChallengeData2016/annotations/validation',
+    'Folder containing annotations for validation')
+
+tf.app.flags.DEFINE_string(
+    'output_dir', './tfrecord',
+    'Path to save converted SSTable of Tensorflow example')
+# Number of TFRecord shards written per split.
+_NUM_SHARDS = 4
+
+def _convert_dataset(dataset_split, dataset_dir, dataset_label_dir):
+  """Converts one ADE20K split into sharded TFRecord files.
+
+  Args:
+    dataset_split: dataset split (e.g., train, val)
+    dataset_dir: dir in which the *.jpg images are located
+    dataset_label_dir: dir in which the *.png annotations are located
+
+  Raises:
+    RuntimeError: If loaded image and label have different shape.
+  """
+
+  img_names = glob.glob(os.path.join(dataset_dir, '*.jpg'))
+  random.shuffle(img_names)
+  seg_names = []
+  for f in img_names:
+    # get the filename without the extension
+    basename = os.path.splitext(os.path.basename(f))[0]
+    # the annotation shares the image's basename but is stored as .png
+    seg = os.path.join(dataset_label_dir, basename+'.png')
+    seg_names.append(seg)
+
+  num_images = len(img_names)
+  # Divide first, then round up, so no trailing images are dropped.
+  num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))
+  image_reader = build_data.ImageReader('jpeg', channels=3)
+  label_reader = build_data.ImageReader('png', channels=1)
+
+  for shard_id in range(_NUM_SHARDS):
+    output_filename = os.path.join(
+        FLAGS.output_dir,
+        '%s-%05d-of-%05d.tfrecord' % (dataset_split, shard_id, _NUM_SHARDS))
+    with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
+      start_idx = shard_id * num_per_shard
+      end_idx = min((shard_id + 1) * num_per_shard, num_images)
+      for i in range(start_idx, end_idx):
+        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
+            i + 1, num_images, shard_id))
+        sys.stdout.flush()
+        # Read the raw image bytes ('rb': JPEG/PNG data is binary, not text).
+        with tf.gfile.FastGFile(img_names[i], 'rb') as image_file:
+          image_data = image_file.read()
+        height, width = image_reader.read_image_dims(image_data)
+        # Read the semantic segmentation annotation.
+        with tf.gfile.FastGFile(seg_names[i], 'rb') as seg_file:
+          seg_data = seg_file.read()
+        seg_height, seg_width = label_reader.read_image_dims(seg_data)
+        if height != seg_height or width != seg_width:
+          raise RuntimeError('Shape mismatched between image and label.')
+        # Convert to tf example.
+        example = build_data.image_seg_to_tfexample(
+            image_data, img_names[i], height, width, seg_data)
+        tfrecord_writer.write(example.SerializeToString())
+    sys.stdout.write('\n')
+    sys.stdout.flush()
+
+def main(unused_argv):  # argv is unused; all inputs come from FLAGS
+ tf.gfile.MakeDirs(FLAGS.output_dir)  # create the output dir before writing shards
+ _convert_dataset('train', FLAGS.train_image_folder, FLAGS.train_image_label_folder)
+ _convert_dataset('val', FLAGS.val_image_folder, FLAGS.val_image_label_folder)
+
+if __name__ == '__main__':
+ tf.app.run()  # parses the flags, then calls main()
diff --git a/research/deeplab/datasets/build_ade20k_eval_data.py b/research/deeplab/datasets/build_ade20k_eval_data.py
new file mode 100644
index 0000000..a1acead
--- /dev/null
+++ b/research/deeplab/datasets/build_ade20k_eval_data.py
@@ -0,0 +1,104 @@
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import glob
+import math
+import os
+import random
+import string
+import sys
+from PIL import Image
+import build_data
+import tensorflow as tf
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+# Images and their *_seg.png annotations are read from the same folders.
+flags.DEFINE_string(
+    'eval_image_folder', './ADE20K_2016_07_26/images/training',
+    'Folder containing trainng images and segmentation annotations')
+
+flags.DEFINE_string(
+    'output_dir', './tfrecord',
+    'Path to save converted SSTable of Tensorflow example')
+# Number of TFRecord shards written for the split.
+_NUM_SHARDS = 1
+
+def _convert_dataset(dataset_split, dataset_dir):
+  """Converts the lettered ADE20K image folders into sharded TFRecord files.
+
+  Args:
+    dataset_split: dataset split (e.g., train, test)
+    dataset_dir: dir whose <letter>/<sub-folder>/ dirs hold the images
+
+  Raises:
+    RuntimeError: If loaded image and label have different shape.
+  """
+
+  img_names = []
+  seg_names = []
+  for letter in string.ascii_lowercase:
+    dir_name = os.path.join(dataset_dir, letter)
+    img_names.extend(glob.glob(os.path.join(dir_name, '*/*.jpg')))
+    img_names.extend(glob.glob(os.path.join(dir_name, '*/*.jpeg')))
+  random.shuffle(img_names)
+  for f in img_names:
+    # get the filename without the extension
+    basename = os.path.splitext(os.path.basename(f))[0]
+    # the annotation sits next to the image as <basename>_seg.png
+    seg = os.path.join(os.path.dirname(f), basename+'_seg.png')
+    seg_names.append(seg)
+
+  num_images = len(img_names)
+  # Divide first, then round up, so no trailing images are dropped.
+  num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))
+  image_reader = build_data.ImageReader('jpeg', channels=3)
+  label_reader = build_data.ImageReader('png', channels=1)
+
+  for shard_id in range(_NUM_SHARDS):
+    output_filename = os.path.join(
+        FLAGS.output_dir,
+        '%s-%05d-of-%05d.tfrecord' % (dataset_split, shard_id, _NUM_SHARDS))
+    with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
+      start_idx = shard_id * num_per_shard
+      end_idx = min((shard_id + 1) * num_per_shard, num_images)
+      for i in range(start_idx, end_idx):
+        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
+            i + 1, num_images, shard_id))
+        sys.stdout.flush()
+        # Read the raw image bytes ('rb': JPEG/PNG data is binary, not text).
+        with tf.gfile.FastGFile(img_names[i], 'rb') as image_file:
+          image_data = image_file.read()
+        height, width = image_reader.read_image_dims(image_data)
+        # Read the semantic segmentation annotation.
+        with tf.gfile.FastGFile(seg_names[i], 'rb') as seg_file:
+          seg_data = seg_file.read()
+        seg_height, seg_width = label_reader.read_image_dims(seg_data)
+        if height != seg_height or width != seg_width:
+          raise RuntimeError('Shape mismatched between image and label.')
+        # Convert to tf example.
+        example = build_data.image_seg_to_tfexample(
+            image_data, img_names[i], height, width, seg_data)
+        tfrecord_writer.write(example.SerializeToString())
+    sys.stdout.write('\n')
+    sys.stdout.flush()
+
+def main(unused_argv):  # argv is unused; all inputs come from FLAGS
+ tf.gfile.MakeDirs(FLAGS.output_dir)  # create the output dir before writing shards
+ _convert_dataset('eval', FLAGS.eval_image_folder)
+
+if __name__ == '__main__':
+ tf.app.run()  # parses the flags, then calls main()
diff --git a/research/deeplab/datasets/segmentation_dataset.py b/research/deeplab/datasets/segmentation_dataset.py
index a777252..cedeb54 100644
--- a/research/deeplab/datasets/segmentation_dataset.py
+++ b/research/deeplab/datasets/segmentation_dataset.py
@@ -85,10 +85,24 @@ _PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
ignore_label=255,
)
+# The split sizes below are hard-coded: each must equal the number of examples
+# in that split, so you have to work them out for your own train/test data.
+# TODO: find a way to derive these counts automatically.
+_ADE20K_INFORMATION = DatasetDescriptor(
+ splits_to_sizes = {
+ 'train': 20210, # num of samples in images/training
+ 'val': 2000, # num of samples in images/validation
+ 'eval': 2, # presumably a sample value; set to the size of your own eval set
+ },
+ num_classes=150,
+ ignore_label=255, # NOTE(review): confirm both values against the mask label encoding
+)
+
_DATASETS_INFORMATION = {
'cityscapes': _CITYSCAPES_INFORMATION,
'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
+ 'ade20k': _ADE20K_INFORMATION,
}
# Default file pattern of TFRecord of TensorFlow Example.
diff --git a/research/deeplab/train_ade20k.sh b/research/deeplab/train_ade20k.sh
new file mode 100755
index 0000000..2005409
--- /dev/null
+++ b/research/deeplab/train_ade20k.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# NOTE To use the ADE20K dataset, do
+# 1. set `--initialize_last_layer=False` so that we can have @num_classes==150
+# 2. set proper values for `--min_resize_value` (e.g., 350) and
+#    `--max_resize_value` (e.g., 500) and let `--resize_factor` equal
+#    `--output_stride`
+# 3. set @exclude_list in utils/train_utils.py to only include _LOGITS_SCOPE_NAME
+# 4. modify the setting of _ADE20K_INFORMATION in datasets/segmentation_dataset.py
+#
+# Note also that, as stated in the doc, if using a fine-tuned pretrained model,
+# one can set `fine_tune_batch_norm=False` and use a smaller batch size (you
+# will need a batch size >= 16 if you are going to fine-tune the BN parameters)
+
+DEEPLAB_DIR="${DEEPLAB_DIR:-/home/yubin/project/models/research/deeplab}"  # override via env
+python train.py \
+    --logtostderr \
+    --training_number_of_steps=130000 \
+    --train_split="train" \
+    --model_variant="xception_65" \
+    --atrous_rates=6 \
+    --atrous_rates=12 \
+    --atrous_rates=18 \
+    --output_stride=16 \
+    --decoder_output_stride=4 \
+    --train_crop_size=513 \
+    --train_crop_size=513 \
+    --train_batch_size=4 \
+    --min_resize_value=350 \
+    --max_resize_value=500 \
+    --resize_factor=16 \
+    --fine_tune_batch_norm=False \
+    --dataset="ade20k" \
+    --initialize_last_layer=False \
+    --tf_initial_checkpoint="${DEEPLAB_DIR}/deeplabv3_pascal_train_aug/model.ckpt" \
+    --train_logdir="${DEEPLAB_DIR}/datasets/exp/ade20k/train" \
+    --dataset_dir="${DEEPLAB_DIR}/datasets/ADE20K/tfrecord/"
+
diff --git a/research/deeplab/utils/train_utils.py b/research/deeplab/utils/train_utils.py
index e2562bf..9bc9cbc 100644
--- a/research/deeplab/utils/train_utils.py
+++ b/research/deeplab/utils/train_utils.py
@@ -99,7 +99,8 @@ def get_model_init_fn(train_logdir,
tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)
# Variables that will not be restored.
- exclude_list = ['global_step']
+  # Keep 'global_step' excluded so the checkpoint's step counter is not restored.
+  exclude_list = ['global_step', 'logits']  # 'logits': model._LOGITS_SCOPE_NAME
if not initialize_last_layer:
exclude_list.extend(last_layers)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.