@Kaapp
Last active April 30, 2018 16:39
import csv
import hashlib
import re
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.python.util import compat
class DataHandler:
# 112120 images
# 70% training, 10% validation, 20% testing
# ~78484 training, ~11212 validation, ~22424 testing
# 75712 training, 10812 validation, 25596 testing <- actual splits.
def __init__(self, multi_label=True):
self.MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 #~134M
self.TOTAL_IMAGES = 112120
self.training_percentage = 70
self.validation_percentage = 10
self.testing_percentage = 20
self.multi_label = multi_label
if multi_label:
self.GROUND_TRUTHS = ['Cardiomegaly','Emphysema','Effusion','Hernia','Infiltration',
'Mass','Nodule','Atelectasis','Pneumothorax','Pleural_Thickening',
'Pneumonia','Fibrosis','Edema','Consolidation']
self.image_list = self.create_multilabel_label_dict()
else:
self.GROUND_TRUTHS = ['Pathology', 'No Pathology']
self.image_list = self.create_singlelabel_label_dict()
return None
def create_multilabel_label_dict(self):
'''
1. Create a mapping filename -> split flag from the txt files, e.g. file_mapping = { "00000001_000.png": 1 } (1 = train/val, 0 = test). O(n)
2. Build the image list by iterating the csv line by line, checking the mapping to see which data set each file belongs to. O(n)
3. For the train/val set, hash the patient number to get an approximate split. -> two O(n) passes to create.
'''
image_list = {
'training': [],
'validation': [],
'testing': []
}
file_mapping = {}
with open('./train_val_list.txt') as file:
train_files = file.read().splitlines()
for file_name in train_files:
file_mapping[file_name] = 1
with open('./test_list.txt') as file:
test_files = file.read().splitlines()
for file_name in test_files:
file_mapping[file_name] = 0
first_line = True
with open('../data/Data_Entry_2017.csv', 'r') as csvfile:
reader = csv.reader(csvfile)
for row in reader:
if first_line:
first_line = False
continue
# row[0] = filename
# row[1] = ground truths
file_name = row[0]
try:
if file_mapping[file_name] == 1:
# Train/validation set, need to hash to split
percentage_hash = self.get_percentage_hash(row[0])
if percentage_hash < 12.5: # validation is 10% of all data, i.e. 12.5% of the remaining 80% train/val split
image_list['validation'].append((file_name, self.new_y_array(row[1])))
else:
image_list['training'].append((file_name, self.new_y_array(row[1])))
else:
image_list['testing'].append((file_name, self.new_y_array(row[1])))
except KeyError:
pass
return image_list
def create_singlelabel_label_dict(self):
return []
def get_percentage_hash(self, file_name):
# Hash only the patient number so that multiple images from the same patient
# compute the same hash so they will be placed in the same subset.
file_name = re.sub(r"_[0-9]{3}\.png", "", file_name)
file_name_hashed = hashlib.sha1(compat.as_bytes(file_name)).hexdigest()
percentage_hash = ((int(file_name_hashed, 16) %
(self.MAX_NUM_IMAGES_PER_CLASS + 1)) *
(100.0 / self.MAX_NUM_IMAGES_PER_CLASS))
return percentage_hash
def new_y_array(self, truth_string):
array = np.zeros(len(self.GROUND_TRUTHS), dtype=np.float32)
if self.multi_label:
labels_array = truth_string.split('|')
for label in labels_array:
try:
label_index = self.GROUND_TRUTHS.index(label)
array[label_index] = 1
except ValueError:
pass #do nothing, it's No Finding which we encode as all zeros
return array
def image_parse_function(self, filename, label):
image_string = tf.read_file('../data/images/multi-label/' + filename)
image_decoded = tf.image.decode_png(image_string, channels=1)
image_resized = tf.image.resize_images(image_decoded, [256,256])
image_cropped = tf.image.crop_to_bounding_box(image_resized, 16, 16, 224, 224)
return image_cropped, label
def get_dataset(self, data_type='training', num_examples=0):
if num_examples < 0:
raise ValueError('Invalid num_examples: %d' % num_examples)
size = len(self.image_list[data_type])
features = []
labels = []
if num_examples == 0 or num_examples >= size:
for feature, label in self.image_list[data_type]:
features.append(feature)
labels.append(label)
else:
for index in range(num_examples):
feature, label = self.image_list[data_type][index]
features.append(feature)
labels.append(label)
return features, labels
def get_pathology_counts(self, data_type='validation'):
image_dict = {}
pathology_dict = {
'multi-label': []
}
with open('./' + data_type + '_images.txt') as file:
images = file.read().splitlines()
for image in images:
image_dict[image] = 1
with open('../data/Data_Entry_2017.csv') as file:
first_line = True
reader = csv.reader(file)
for row in reader:
if first_line:
first_line = False
continue
# row[0] = filename
# row[1] = ground truths
if row[0] in image_dict:
labels = row[1].split('|')
if len(labels) > 1:
pathology_dict['multi-label'].append(row[0])
else:
if labels[0] not in pathology_dict:
pathology_dict[labels[0]] = []
pathology_dict[labels[0]].append(row[0])
return pathology_dict
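# --- Illustrative usage sketch (not part of the original gist) ---
# A minimal example of how DataHandler behaves, assuming train_val_list.txt,
# test_list.txt and ../data/Data_Entry_2017.csv exist at the paths hard-coded
# above. Because get_percentage_hash() strips the "_NNN.png" suffix, every
# image of a patient hashes to the same percentage and lands in the same subset.
if __name__ == '__main__':
    handler = DataHandler(multi_label=True)
    for subset in ('training', 'validation', 'testing'):
        print(subset, len(handler.image_list[subset]))
    # Same patient number -> same hash -> same subset:
    print(handler.get_percentage_hash('00000001_000.png') ==
          handler.get_percentage_hash('00000001_001.png'))   # True
    # new_y_array() multi-hot encodes pathologies; 'No Finding' is all zeros:
    print(handler.new_y_array('Effusion|Mass'))
    print(handler.new_y_array('No Finding'))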
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
# See: https://arxiv.org/pdf/1409.4842.pdf
class GoogLeNet:
def __init__(self):
self.num_labels = 14
self.NAME = "GoogLeNet"
#training params
self.batch_size = 64
self.learning_rate = 0.1
self.weight_decay = 0.001
return None
def construct_graph(self, x, y):
self.graph = tf.get_default_graph()
self.lr = tf.placeholder(tf.float32, shape=[], name='LR')
self.keep_prob = tf.placeholder(tf.float32, shape=[], name='keep_prob')
self.is_training = tf.placeholder(tf.bool, shape=[], name='is_training')
model = self.conv(x, filters=64, kernel_size=7, stride=2, name='conv1_k7_s2')
model = self.max_pool(model, pool_size=3, stride=2, name="maxpool1_p3_s2")
model = tf.nn.local_response_normalization(input=model, alpha=0.0001, beta=0.75)
model = self.conv(model, filters=64, kernel_size=1, stride=1, name='conv2_k1_s1')
model = self.conv(model, filters=192, kernel_size=3, stride=1, name='conv2_k3_s1')
model = tf.nn.local_response_normalization(input=model, alpha=0.0001, beta=0.75)
model = self.max_pool(model, pool_size=3, stride=2, name='maxpool2_p3_s2')
model = self._inception_module(model, filters=[64, 96, 128, 16, 32, 32],
name='inception3a')
model = self._inception_module(model, filters=[128, 128, 192, 32, 96, 64],
name='inception3b')
model = self.max_pool(model, pool_size=3, stride=2, name='maxpool3_p3_s2')
model = self._inception_module(model, filters=[192, 96, 208, 16, 48, 64],
name='inception4a')
model = self._inception_module(model, filters=[160, 112, 224, 24, 64, 64],
name='inception4b')
model = self._inception_module(model, filters=[128, 128, 256, 24, 64, 64],
name='inception4c')
model = self._inception_module(model, filters=[112, 144, 288, 32, 64, 64],
name='inception4d')
model = self._inception_module(model, filters=[256, 160, 320, 32, 128, 128],
name='inception4e')
model = self.max_pool(model, pool_size=3, stride=2, name='maxpool4_p3_s2')
model = self._inception_module(model, filters=[256, 160, 320, 32, 128, 128],
name='inception5a')
model = self._inception_module(model, filters=[384, 192, 384, 48, 128, 128],
name='inception5b')
model = self.avg_pool(model, pool_size=7, stride=1, name='avgpool5_p7_s1')
#model = tf.reshape(model, [-1, 7 * 7 * 1024])
logits = self.fully_connected(model)
self.ys_pred = tf.nn.sigmoid(logits, name='prediction')
with tf.name_scope('loss'):
total_labels = tf.to_float(tf.multiply(self.batch_size, self.num_labels))
num_positive_labels = tf.count_nonzero(y, dtype=tf.float32)
num_negative_labels = tf.subtract(total_labels, num_positive_labels)
Bp = tf.divide(total_labels, num_positive_labels)
Bn = tf.divide(total_labels, num_negative_labels)
cross_entropy = -tf.reduce_sum((tf.multiply(Bp, y * tf.log(self.ys_pred + 1e-9))) +
(tf.multiply(Bn, (1-y) * tf.log(1-self.ys_pred + 1e-9))),
name="cross_entropy")
l2 = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
# The loss function
self.loss = cross_entropy + l2 * self.weight_decay
# Training the network with Adam using standard parameters.
#self.train_step = tf.train.AdamOptimizer(
# learning_rate=self.lr,
# beta1=0.9,
# beta2=0.999).minimize(self.loss)
self.train_step = tf.train.AdagradOptimizer(learning_rate=self.lr).minimize(self.loss)
# Define some wrapper functions for brevity/readability
def conv(self, inputs, filters, kernel_size, stride, name, padding='SAME',
activation=tf.nn.relu):
return tf.layers.conv2d(
inputs=inputs,
filters=filters,
kernel_size=[kernel_size, kernel_size],
strides=stride,
padding=padding,
activation=activation,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.001),
name=name)
def max_pool(self, inputs, pool_size, stride, name):
return tf.layers.max_pooling2d(
inputs=inputs,
pool_size=[pool_size, pool_size],
strides=stride,
padding='SAME',
name=name)
def avg_pool(self, inputs, pool_size, stride, name):
return tf.layers.average_pooling2d(
inputs=inputs,
pool_size=[pool_size, pool_size],
strides=stride,
padding='VALID',
name=name)
def fully_connected(self, inputs):
dropout = tf.layers.dropout(inputs, rate=1 - self.keep_prob, training=self.is_training)
# Need to reshape dropout to 2D tensor for FC layer, multiply the dimensions excluding
# batch size
new_shape = int(np.prod(self._get_tensor_shape(dropout)[1:]))
dropout = tf.reshape(dropout, [-1, new_shape])
return tf.layers.dense(dropout, self.num_labels)
def _get_tensor_shape(self, tensor):
return tensor.get_shape().as_list()
def _inception_module(self, inputs, filters, name):
if len(filters) != 6:
raise ValueError('Invalid filters input')
# From left to right in the graph @ https://arxiv.org/pdf/1409.4842.pdf fig.3
with tf.name_scope(name):
conv1_k1_s1 = self.conv(inputs, filters=filters[0], kernel_size=1, stride=1,
name=name + '_conv1_k1_s1')
conv2_k1_s1 = self.conv(inputs, filters=filters[1], kernel_size=1, stride=1,
name=name + '_conv2_k1_s1')
conv3_k3_s1 = self.conv(conv2_k1_s1, filters=filters[2], kernel_size=3, stride=1,
name=name + '_conv3_k3_s1')
conv4_k1_s1 = self.conv(inputs, filters=filters[3], kernel_size=1, stride=1,
name=name + '_conv4_k1_s1')
conv5_k5_s1 = self.conv(conv4_k1_s1, filters=filters[4], kernel_size=5, stride=1,
name=name + '_conv5_k5_s1')
pool1_p3_s1 = self.max_pool(inputs, pool_size=3, stride=1, name=name + '_pool1_p3_s1')
conv6_k1_s1 = self.conv(pool1_p3_s1, filters=filters[5], kernel_size=1, stride=1,
name=name + '_conv6_k1_s1')
tensor_list = [conv1_k1_s1, conv3_k3_s1, conv5_k5_s1, conv6_k1_s1]
return tf.concat(tensor_list, axis=3, name=name + '_merge')
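# --- Illustrative sketch (not part of the original gist) ---
# NumPy rendering of the class-balanced cross-entropy built in construct_graph
# above: positive terms are weighted by Bp = total / num_positive and negative
# terms by Bn = total / num_negative, so the scarce positive labels are not
# swamped by the far more common zeros. Uses the `np` import at the top of this
# file and, like the TF version, assumes the batch contains at least one
# positive and one negative label.
def balanced_cross_entropy(y, y_pred, eps=1e-9):
    total = float(y.size)                 # batch_size * num_labels
    num_pos = float(np.count_nonzero(y))
    num_neg = total - num_pos
    bp = total / num_pos
    bn = total / num_neg
    return -np.sum(bp * y * np.log(y_pred + eps) +
                   bn * (1 - y) * np.log(1 - y_pred + eps))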
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
class NaiveCNN:
def __init__(self):
self.num_labels = 14
self.NAME = "NaiveCNN"
#training params
self.batch_size = 64
self.learning_rate = 0.001
self.weight_decay = 0.0001
return None
def construct_graph(self, x, y):
self.graph = tf.get_default_graph()
# 224x224xGrayscale input data, cropped from 256x256 8bit greyscale PNG
#self.xs = tf.placeholder(tf.float32, shape=[None, 224, 224, 1])
# 14 possible pathologies
#self.ys = tf.placeholder(tf.float32, shape=[None, self.num_labels])
self.lr = tf.placeholder(tf.float32, shape=[])
self.keep_prob = tf.placeholder(tf.float32, shape=[])
self.is_training = tf.placeholder(tf.bool, shape=[])
model = tf.layers.conv2d(
inputs=x,
filters=64, # number of outputs
kernel_size=[7,7],
strides=2,
padding='SAME',
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.001),
name="conv_1_7_2") # conv num 1, size 7, stride 2
model = tf.layers.max_pooling2d(
inputs=model,
pool_size=[3,3],
strides=2,
name="pool_1_3_2") # pool num 1, size 3, stride 2
model = tf.layers.conv2d(
inputs=model,
filters=64,
kernel_size=[3,3],
strides=1,
padding='SAME',
kernel_initializer=tf.truncated_normal_initializer(stddev=0.001),
activation=tf.nn.relu,
name="conv_2_3_1") # conv num 2, size 3, stride 1
model = tf.layers.max_pooling2d(
inputs=model,
pool_size=[3,3],
strides=2,
name="pool_2_3_2")
# flatten input tensor before dense layer
model = tf.reshape(model, [-1, 27 * 27 * 64])
model = tf.layers.dense(inputs=model, units=1024, activation=tf.nn.relu)
model = tf.layers.dropout(inputs=model, rate=1 - self.keep_prob, training=self.is_training)
model = tf.layers.dense(inputs=model, units=self.num_labels)
# The layer used to get predictions from the network
# We will use this to calculate AUROC in testing
self.ys_pred = tf.nn.sigmoid(model, name="prediction")
# OLD
#with tf.name_scope('loss'):
# cross_entropy = -tf.reduce_sum((y * tf.log(self.ys_pred + 1e-9)) + ((1-y) * tf.log(1-self.ys_pred + 1e-9)), name="cross_entropy")
# l2 = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
# # The loss function, an element-wise sigmoid non-linearity
# self.loss = cross_entropy + l2 * self.weight_decay
with tf.name_scope('loss'):
total_labels = tf.to_float(tf.multiply(self.batch_size, self.num_labels))
num_positive_labels = tf.count_nonzero(y, dtype=tf.float32)
num_negative_labels = tf.subtract(total_labels, num_positive_labels)
Bp = tf.divide(total_labels, num_positive_labels)
Bn = tf.divide(total_labels, num_negative_labels)
cross_entropy = -tf.reduce_sum((tf.multiply(Bp, y * tf.log(self.ys_pred + 1e-9))) +
(tf.multiply(Bn, (1-y) * tf.log(1-self.ys_pred + 1e-9))),
name="cross_entropy")
l2 = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
# The loss function
self.loss = cross_entropy + l2 * self.weight_decay
# Training the network with Adam using standard parameters.
#self.train_step = tf.train.AdamOptimizer(
# learning_rate=self.lr,
# beta1=0.9,
# beta2=0.999).minimize(self.loss)
self.train_step = tf.train.AdagradOptimizer(learning_rate=self.lr).minimize(self.loss)
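# --- Illustrative sketch (not part of the original gist) ---
# Worked dimension check for the hard-coded reshape to 27 * 27 * 64 above,
# assuming 224x224x1 inputs. tf.layers.conv2d with padding='SAME' gives
# ceil(in / stride); the pooling layers use the default padding='valid',
# which gives floor((in - pool) / stride) + 1.
def _same_conv(size, stride):
    return -(-size // stride)            # ceil division

def _valid_pool(size, pool, stride):
    return (size - pool) // stride + 1

if __name__ == '__main__':
    s = 224
    s = _same_conv(s, 2)                 # conv 7x7, stride 2, SAME  -> 112
    s = _valid_pool(s, 3, 2)             # max pool 3x3, stride 2    -> 55
    s = _same_conv(s, 1)                 # conv 3x3, stride 1, SAME  -> 55
    s = _valid_pool(s, 3, 2)             # max pool 3x3, stride 2    -> 27
    print(s, s * s * 64)                 # 27  46656 == 27 * 27 * 64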
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
#import tensorflow as tf
import numpy as np
import GoogLeNet
import DataHandler
import tensorflow as tf
import os.path
import re
NUM_EPOCHS = 30
VALIDATION_SET_SIZE = 10000
def get_num_trainable_params():
total_parameters = 0
for variable in tf.trainable_variables():
shape = variable.get_shape()
variable_parameters = 1
for dim in shape:
variable_parameters *= dim.value
total_parameters += variable_parameters
return total_parameters
def add_summary_ops(ground_truth):
# We will round our network's predictions such that >50% presence is positive, <=50% presence is negative
thresholds = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
# For inference, we will display the actual percentages
p, _ = tf.metrics.precision_at_thresholds(labels=ground_truth, predictions=network.ys_pred, thresholds=thresholds)
r, _ = tf.metrics.recall_at_thresholds(labels=ground_truth, predictions=network.ys_pred, thresholds=thresholds)
# Using F1 because false negative and false positive are equally bad in medicine
precision = tf.reduce_mean(p)
recall = tf.reduce_mean(r)
f1 = 2 * precision * recall / (precision + recall)
with tf.name_scope("summaries"):
tf.summary.scalar("loss", network.loss)
# Plotting learning rate forces us to feed learning rate even when we don't train.
tf.summary.scalar("learning_rate", network.lr)
tf.summary.scalar("precision", precision)
tf.summary.scalar("recall", recall)
tf.summary.scalar("f1_score", f1)
network.summary_op = tf.summary.merge_all()
return p, r, f1
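# Note: tf.metrics.precision_at_thresholds / recall_at_thresholds pool over
# every label and return one value per entry in `thresholds`, so with fourteen
# identical 0.5 entries the reduce_mean inside add_summary_ops is simply the
# precision/recall at a 0.5 cut-off, not a per-pathology average.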
# Initialise network values
network = GoogLeNet.GoogLeNet()
# Get our list of files and their labels, and create our placeholders to feed
data = DataHandler.DataHandler()
train_features, train_labels = data.get_dataset('training')
val_features, val_labels = data.get_dataset('validation')
VALIDATION_SET_SIZE = len(val_features)
features_placeholder = tf.placeholder(tf.string, shape=[None])
labels_placeholder = tf.placeholder(tf.float32, shape=[None, len(data.GROUND_TRUTHS)])
# Create a dataset from our placeholders
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# Map the filenames to the actual image data
dataset = dataset.map(data.image_parse_function)
# Split the dataset into batches depending on the network's specified batch size.
dataset = dataset.batch(network.batch_size)
# Create an iterator for our datasets
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, dataset.output_types, dataset.output_shapes)
train_iterator = dataset.make_initializable_iterator()
val_iterator = dataset.make_initializable_iterator()
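# Both iterators come from the same placeholder-backed dataset, so they share
# output types and shapes. The string `handle` placeholder is the feedable
# iterator pattern: feeding the train or validation handle at sess.run time
# switches which pipeline iterator.get_next() below pulls from, without
# rebuilding the graph.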
# Get our final image data and label from the iterator, pass it to the network and let
# the network build its graph, followed by the summary ops
(x, y) = iterator.get_next()
network.construct_graph(x, y)
p, r, f1 = add_summary_ops(y)
# Create our summary file writer so we can track our progress on TensorBoard
train_writer = tf.summary.FileWriter('./train_logs/' + network.NAME + '/train', network.graph)
val_writer = tf.summary.FileWriter('./train_logs/' + network.NAME + '/val', network.graph)
# Start a session
with tf.Session(graph=network.graph) as sess:
# Create a saver so we can save/load model checkpoints after epochs
saver = tf.train.Saver()
batches_completed = 0
epochs_completed = 0
# Look for existing ckpt file else initialise!
available_ckpts = [int(re.match(r"(?:[a-zA-Z]*_)([0-9]*)(?:\.ckpt\.txt)", f).group(1))
for f in os.listdir('./checkpoints/' + network.NAME + '/')
if f.endswith('.ckpt.txt')]
if len(available_ckpts) > 0:
# Sort the list of checkpoint numbers in descending order so first entry is latest
available_ckpts.sort(reverse=True)
print('Restoring from epoch {0}'.format(available_ckpts[0]))
saver.restore(sess, './checkpoints/{0}/{0}_{1}.ckpt'.format(network.NAME, available_ckpts[0]))
# load epoch and batch values from old model
with open('./checkpoints/{0}/{0}_{1}.ckpt.txt'.format(network.NAME, available_ckpts[0])) as info_file:
values = info_file.read().splitlines()
if len(values) == 4:
batches_completed = int(values[1])
epochs_completed = int(values[3])
else:
# Initialise our global vars (W and b)
sess.run(tf.global_variables_initializer())
# Initialise our local vars (for calculating training/validation precision/recall/f1)
sess.run(tf.local_variables_initializer())
# Print the current models number of training params
print("Total training params: %.1fM" % (get_num_trainable_params() / 1e6))
# Get the iterator handles to feed for train/val/test
train_handle = sess.run(train_iterator.string_handle())
val_handle = sess.run(val_iterator.string_handle())
# For each batch --- should the learning rate drop to 0.01 at epoch 150 and 0.001 at epoch 225?
no_improvement_last_epoch = False
old_loss = 2**32 - 1 # A large number in case this is our first run
# Compute for NUM_EPOCHS
while epochs_completed < NUM_EPOCHS:
# Initialise our iterators with data (this also resets them to the beginning of their dataset)
sess.run(train_iterator.initializer, feed_dict={features_placeholder: train_features, labels_placeholder: train_labels})
sess.run(val_iterator.initializer, feed_dict={features_placeholder: val_features, labels_placeholder: val_labels})
while True:
try:
# Every 1000 batches, also trace runtime statistics for debugging memory usage/compute time
if batches_completed % 1000 == 0:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
_, loss, prediction, summary, _x, _y, _p, _r, _f1 = sess.run([network.train_step, network.loss, network.ys_pred, network.summary_op, x, y, p, r, f1 ],
feed_dict={
handle: train_handle,
network.lr: network.learning_rate,
network.is_training: True,
network.keep_prob: 0.8
},
options=run_options,
run_metadata=run_metadata)
train_writer.add_run_metadata(run_metadata, 'batch{0}'.format(batches_completed))
train_writer.add_summary(summary, global_step=batches_completed)
# else just train normally
else:
_, loss, prediction, summary, _x, _y, _p, _r, _f1 = sess.run([network.train_step, network.loss, network.ys_pred, network.summary_op, x, y, p, r, f1 ],
feed_dict={
handle: train_handle,
network.lr: network.learning_rate,
network.is_training: True,
network.keep_prob: 0.8
})
train_writer.add_summary(summary, global_step=batches_completed)
# Also run a validation batch every 20 batches for TensorBoard
if batches_completed % 20 == 0:
loss, prediction, summary, _x, _y, _p, _r, _f1 = sess.run([network.loss, network.ys_pred, network.summary_op, x, y, p, r, f1 ],
feed_dict={
handle: val_handle,
network.lr: network.learning_rate,
network.is_training: False,
network.keep_prob: 1.0
})
val_writer.add_summary(summary, global_step=batches_completed)
batches_completed = batches_completed + 1
# If we ran out of data, that's the end of our epoch
except tf.errors.OutOfRangeError:
break
# After our epoch, calculate mean loss over full validation set
sess.run(val_iterator.initializer, feed_dict={ features_placeholder: val_features, labels_placeholder: val_labels })
total_loss = 0
while True:
try:
loss, _preds, _y, _x, _p, _r, _f1 = sess.run([network.loss, network.ys_pred, y, x, p, r, f1],
feed_dict={
handle: val_handle,
network.lr: network.learning_rate,
network.is_training: False,
network.keep_prob: 1.0
})
total_loss += loss
# run predictions until validation set is exhausted
except tf.errors.OutOfRangeError:
break
# Compare the mean validation loss to the previous model's; either drop the learning rate or stop early if there's no improvement
mean_loss = total_loss / VALIDATION_SET_SIZE
try:
# Try to read old loss from previous checkpoint
with open('./checkpoints/' + network.NAME + '/' + network.NAME + '_%d.ckpt.txt' % (epochs_completed - 1), mode='r') as file:
ckpt_info = file.read().splitlines()
old_loss = float(ckpt_info[0])
old_learning_rate = float(ckpt_info[2])
except:
# Must be first checkpoint
pass
# If we didn't improve
if mean_loss >= old_loss:
# and we just dropped the learning rate last epoch
if no_improvement_last_epoch:
# Stop training early
print("We're done! Best model was after {0} epochs at {1} mean loss.".format((epochs_completed - 2), old_loss))
break
else: # Decay learning rate by factor of 10, and take the previous weights
network.learning_rate = network.learning_rate * 0.1
saver.restore(sess, './checkpoints/' + network.NAME + '/' + network.NAME + '_%d.ckpt' % (epochs_completed - 1))
mean_loss = old_loss
# If we still don't improve next time after lowering learning rate
no_improvement_last_epoch = True
else:
no_improvement_last_epoch = False
# Save this model as a new checkpoint
file_name = './checkpoints/' + network.NAME + '/' + network.NAME + '_%d.ckpt' % epochs_completed
save_path = saver.save(sess, file_name)
# also save the mean validation loss, global step (batches), current learning rate and epoch count in an associated text file!
with open('./checkpoints/' + network.NAME + '/' + network.NAME + '_%d.ckpt.txt' % epochs_completed, mode='w') as out_file:
out_file.write('{0}\n{1}\n{2}\n{3}'.format(mean_loss, batches_completed, network.learning_rate, epochs_completed))
epochs_completed = epochs_completed + 1
train_writer.close()
val_writer.close()
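# --- Illustrative sketch (not part of the original gist) ---
# Condensed form of the validation-driven schedule in the loop above, where
# `decayed_last_epoch` mirrors no_improvement_last_epoch: if the mean
# validation loss fails to improve, the learning rate is divided by 10 and the
# previous checkpoint is restored; a second consecutive non-improvement stops
# training early.
def update_schedule(mean_loss, old_loss, learning_rate, decayed_last_epoch):
    """Returns (stop_training, new_learning_rate, restore_previous_ckpt, decayed)."""
    if mean_loss >= old_loss:
        if decayed_last_epoch:
            return True, learning_rate, False, True      # stop early
        return False, learning_rate * 0.1, True, True    # decay LR, roll back weights
    return False, learning_rate, False, False            # improved: carry on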
@Engineero

Really nice implementation!
