# train-test
"""
config.yaml
training:
tfrecords:
- 'imgclassification/dataset/img_dataset_0_[len=600]_train.tfrecord'
- 'imgclassification/dataset/img_dataset_1_[len=600]_train.tfrecord'
length:
- 600
- 600
validation:
tfrecords:
- 'imgclassification/dataset/img_dataset_0_[len=30]_test.tfrecord'
- 'imgclassification/dataset/img_dataset_1_[len=30]_test.tfrecord'
length:
- 30
- 30
keep_prob: 0.8
train_dir: 'imgclassification/train_log'
num_threads: 4
"""
import math
import tensorflow as tf
import yaml
from tqdm import tqdm


def inference(images, config, num_class):
    """
    :param images: TensorFlow float tensor [batch_size x image_size x image_size x image_channel]
    :param config: model config dict (see the config.yaml docstring above)
    :param num_class: number of classes, e.g. 2
    :return: unscaled class logits [batch_size x num_class]
    """
    with tf.variable_scope('conv1') as scope:
        conv = tf.layers.conv2d(
            inputs=images,
            filters=32,
            kernel_size=[3, 3],
            padding='SAME',
            activation=tf.nn.relu
        )
        conv = tf.layers.conv2d(
            inputs=conv,
            filters=64,
            kernel_size=[3, 3],
            padding='SAME',
            activation=tf.nn.relu
        )
        pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME')
        # config['keep_prob'] is a keep probability, but tf.layers.dropout takes
        # a drop rate, so convert it here.
        drop = tf.layers.dropout(pool, rate=1.0 - config['keep_prob'], name=scope.name)
    with tf.variable_scope('conv2') as scope:
        conv = tf.layers.conv2d(
            inputs=drop,
            filters=128,
            kernel_size=[3, 3],
            padding='SAME',
            activation=tf.nn.relu
        )
        pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME')
        conv = tf.layers.conv2d(
            inputs=pool,
            filters=128,
            kernel_size=[2, 2],
            padding='SAME',
            activation=tf.nn.relu
        )
        pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME')
        drop = tf.layers.dropout(pool, rate=0.25)
        conv = tf.layers.conv2d(
            inputs=drop,
            filters=128,
            kernel_size=[2, 2],
            padding='SAME',
            activation=tf.nn.relu
        )
        pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME')
        conv = tf.layers.conv2d(
            inputs=pool,
            filters=128,
            kernel_size=[2, 2],
            padding='SAME',
            activation=tf.nn.relu
        )
        pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME')
        drop = tf.layers.dropout(pool, rate=0.25, name=scope.name)
    print(drop.shape)
    with tf.variable_scope('fully_connected') as scope:
        flat = tf.reshape(drop, [-1, 1 * 1 * 128])
        fc = tf.layers.dense(inputs=flat, units=1500, activation=tf.nn.relu)
        drop = tf.layers.dropout(fc, rate=0.5)
        # No softmax activation here: loss() uses
        # softmax_cross_entropy_with_logits_v2, which expects unscaled logits,
        # so applying softmax in both places would squash the gradients.
        logits = tf.layers.dense(inputs=drop, units=num_class, name=scope.name)
    return logits
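

# Shape sanity sketch (hypothetical helper, never called): the flatten in
# inference() assumes the final feature map is 1x1x128, which holds only when
# image_size == 32, since five stride-2 'SAME' poolings halve the spatial side
# five times (32 -> 16 -> 8 -> 4 -> 2 -> 1).
def _expected_flat_size(image_size, num_pools=5, channels=128):
    side = image_size
    for _ in range(num_pools):
        side = math.ceil(side / 2)  # 'SAME' max pooling rounds up
    return side * side * channels  # _expected_flat_size(32) == 1 * 1 * 128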


def loss(logits, labels, weights):
    """
    :param logits: unscaled class logits from inference()
    :param labels: one-hot labels [batch_size x num_class]
    :param weights: per-class loss weights, e.g. tf.constant([1.0, 1.0])
    :return: scalar weighted cross-entropy loss
    """
    class_weights = weights
    # Per-example weight: pick out the class weight of each example's true class.
    weights = tf.reduce_sum(class_weights * labels, axis=1)
    unweighted_losses = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
    weighted_losses = unweighted_losses * weights
    loss = tf.reduce_mean(weighted_losses)
    tf.summary.scalar('cross_entropy', loss)
    return loss
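

# Worked example (hypothetical numbers; a plain-Python mirror of the weighting
# in loss()): with class_weights = [1.0, 2.0] and one-hot label [0, 1], the
# per-example weight is 1.0*0 + 2.0*1 = 2.0, so class-1 examples contribute
# twice as much to the mean loss.
def _example_weight(class_weights, one_hot_label):
    return sum(w * l for w, l in zip(class_weights, one_hot_label))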


def training(loss, learning_rate):
    """
    :param loss: scalar loss tensor
    :param learning_rate: learning rate for the Adam optimizer
    :return: training op
    """
    train_step = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                        epsilon=1e-08).minimize(loss)
    return train_step


def accuracy(logits, labels):
    """
    :param logits: class logits [batch_size x num_class]
    :param labels: one-hot labels [batch_size x num_class]
    :return: mean prediction accuracy over the batch
    """
    correct_prediction = tf.equal(tf.argmax(logits, axis=1), tf.argmax(labels, axis=1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return accuracy
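

# Worked example for accuracy() (hypothetical values): logits [[0.2, 0.8],
# [0.6, 0.4]] with labels [[0, 1], [1, 0]] agree on both row-wise argmaxes,
# so the returned accuracy would be 1.0.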


def get_tfrecord_serialized(tfrecord_path):
    # Legacy queue-based reader; unused in this script, which instead uses the
    # tf.data pipeline below.
    reader = tf.TFRecordReader()
    tfrecord_file_queue = tf.train.string_input_producer([tfrecord_path], name='queue')
    _, tfrecord_serialized = reader.read(tfrecord_file_queue)
    return tfrecord_serialized


def parse_records(dataset):
    features = {
        'label': tf.FixedLenFeature((), tf.string),
        'image': tf.FixedLenFeature([], tf.string)
    }
    parsed_features = tf.parse_single_example(dataset, features=features)
    return parsed_features['label'], parsed_features['image']


def _read_images(root_config):
    def read_images(label, image):
        label = tf.decode_raw(label, tf.float32)
        label = tf.reshape(label, shape=[root_config['Model']['num_class']])
        image = tf.decode_raw(image, tf.float32)
        image = tf.reshape(image,
                           shape=[root_config['Image']['image_size'],
                                  root_config['Image']['image_size'],
                                  root_config['Image']['image_channel']])
        return image, label
    return read_images
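

# Writer sketch (hypothetical helper, never called): parse_records/read_images
# above expect each Example to hold raw float32 bytes under the 'label' and
# 'image' keys, so a compatible record could be written like this:
def _write_example(writer, image, label):
    import numpy as np  # local import: numpy is only needed for this sketch
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[label.astype(np.float32).tostring()])),
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[image.astype(np.float32).tostring()])),
    }))
    writer.write(example.SerializeToString())  # writer: a tf.python_io.TFRecordWriter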


def create_dataset_iterator(config, root_config):
    read_image = _read_images(root_config)
    training_dataset = tf.data.TFRecordDataset(config['training']['tfrecords'])
    training_dataset = training_dataset.map(parse_records, config['num_threads'])
    training_dataset = training_dataset.map(read_image, config['num_threads'])
    # Shuffle individual examples before batching; shuffling after .batch()
    # would only reorder whole batches.
    training_dataset = training_dataset \
        .shuffle(sum(config['training']['length'])) \
        .repeat(root_config['Model']['epoch']) \
        .batch(root_config['Model']['batch_size'])
    training_iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                                        training_dataset.output_shapes)
    validation_dataset = tf.data.TFRecordDataset(config['validation']['tfrecords'])
    validation_dataset = validation_dataset.map(parse_records, config['num_threads'])
    validation_dataset = validation_dataset.map(read_image, config['num_threads'])
    validation_dataset = validation_dataset \
        .shuffle(sum(config['validation']['length'])) \
        .repeat(-1) \
        .batch(sum(config['validation']['length']))
    validation_iterator = tf.data.Iterator.from_structure(validation_dataset.output_types,
                                                          validation_dataset.output_shapes)
    return training_dataset, training_iterator, validation_dataset, validation_iterator
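

# Usage note: tf.data.Iterator.from_structure builds reinitializable iterators,
# so the same get_next() tensors can be pointed at either dataset; main() below
# binds them with make_initializer before training.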


def main():
    with open('imgclassification/model/config.yaml', 'r', encoding='utf-8') as yml:
        config = yaml.safe_load(yml)
    with open('imgclassification/config.yaml', 'r', encoding='utf-8') as yml:
        root_config = yaml.safe_load(yml)
    training_dataset, training_iterator, validation_dataset, validation_iterator = create_dataset_iterator(config,
                                                                                                           root_config)
    train_init_op = training_iterator.make_initializer(training_dataset)
    valid_init_op = validation_iterator.make_initializer(validation_dataset)
    training_batch = training_iterator.get_next()
    validation_batch = validation_iterator.get_next()
    images_placeholder = tf.placeholder(tf.float32, shape=(None,
                                                           root_config['Image']['image_size'],
                                                           root_config['Image']['image_size'],
                                                           root_config['Image']['image_channel']))
    labels_placeholder = tf.placeholder(tf.float32, shape=(None,
                                                           root_config['Model']['num_class']))
    weights = tf.constant([1.0, 1.0])
    logits = inference(images_placeholder, config, root_config['Model']['num_class'])
    loss_value = loss(logits, labels_placeholder, weights=weights)
    acc = accuracy(logits, labels_placeholder)
    learning_rate = 1e-4
    with tf.name_scope('train'):
        train_op = training(loss_value, learning_rate)
        acc_summary_train_op = tf.summary.scalar('train_acc', acc)
        loss_summary_train_op = tf.summary.scalar('train_loss', loss_value)
    with tf.name_scope('validation'):
        acc_summary_val_op = tf.summary.scalar('val_acc', acc)
        loss_summary_val_op = tf.summary.scalar('val_loss', loss_value)
    # summary_op = tf.summary.merge_all()
    print('[INFO]: CREATE SESSION')
    with tf.Session() as sess:
        sess.run(train_init_op)
        sess.run(valid_init_op)
        summary_writer = tf.summary.FileWriter(config['train_dir'], sess.graph)
        sess.run(tf.global_variables_initializer())
        step = 0
        while True:
            try:
                step += 1
                for i in tqdm(
                        range(math.floor(sum(config['training']['length']) / root_config['Model']['batch_size']) - 1)):
                    images, labels = sess.run(training_batch)
                    sess.run(train_op, feed_dict={
                        images_placeholder: images,
                        labels_placeholder: labels
                    })
                images, labels = sess.run(training_batch)
                res = sess.run([acc_summary_train_op, loss_summary_train_op], feed_dict={
                    images_placeholder: images,
                    labels_placeholder: labels,
                })
                for j in range(len(res)):
                    summary_writer.add_summary(res[j], step)
                images, labels = sess.run(validation_batch)
                res = sess.run([acc_summary_val_op, loss_summary_val_op], feed_dict={
                    images_placeholder: images,
                    labels_placeholder: labels,
                })
                for j in range(len(res)):
                    summary_writer.add_summary(res[j], step)
            except tf.errors.OutOfRangeError:
                break
        print("[INFO] TRAINING FINISH")
        # saver.save(sess, 'imgclassification/' + 'model.ckpt')
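        # (sketch) Enabling the save above would require creating a Saver
        # before the training loop, e.g. saver = tf.train.Saver().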


if __name__ == '__main__':
    main()