Skip to content

Instantly share code, notes, and snippets.

@PatWie
Last active February 10, 2021 00:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PatWie/b6e62847e8997f7f30f26ac50175d352 to your computer and use it in GitHub Desktop.
Save PatWie/b6e62847e8997f7f30f26ac50175d352 to your computer and use it in GitHub Desktop.
Training ResNet50 in TensorFlow 2.0
import tensorflow as tf
import numpy as np
# Batch size passed to tf.data's batch() in get_data. Under MirroredStrategy
# this is the global batch size, which the strategy splits across replicas
# (per the tf.distribute guide).
BATCH = 2
# Number of passes over the FakeData stream in model.fit.
NUM_EPOCHS = 25

# The model trained below is the stock Keras implementation
# (tf.keras.applications.resnet.ResNet50). A standalone ResNet50
# implementation from TensorFlow's eager benchmarks, for reference, is at:
# https://github.com/tensorflow/tensorflow/blob/bd754067dac90182d883f621b775d76ec7c6b87d/tensorflow/python/eager/benchmarks/resnet50/resnet50.py#L1
class FakeData:
    """Synthetic data source that yields one fixed (image, label) pair repeatedly.

    A single random image and a single random integer class label are drawn
    once in ``__init__``; iterating yields that same pair ``length`` times.
    It serves as a stand-in input for exercising the training loop; the model
    cannot learn anything meaningful from it.

    Args:
        length: Number of samples yielded per pass (default 100).
        image_shape: Shape of the float32 image array (default (224, 224, 3)).
        num_classes: The label is drawn uniformly from ``[0, num_classes)``
            (default 1000).
    """

    def __init__(self, length=100, image_shape=(224, 224, 3), num_classes=1000):
        super().__init__()
        self.length = length
        # Values in [0, 1), cast to float32 to match output_types().
        self.X_train = np.random.random(image_shape).astype('float32')
        # Sparse label: an integer class index with shape (1,), NOT a one-hot
        # vector. Pair it with a sparse categorical loss.
        self.Y_train = np.array([np.random.randint(num_classes)]).astype('int32')

    def __iter__(self):
        """Yield ``(X_train, Y_train)`` exactly ``length`` times (same objects)."""
        for _ in range(self.length):
            yield self.X_train, self.Y_train

    def __len__(self):
        """Return the number of samples produced by one iteration."""
        return self.length

    def output_shapes(self):
        """Return ``(image_shape, label_shape)`` for tf.data.Dataset.from_generator."""
        return (self.X_train.shape, self.Y_train.shape)

    def output_types(self):
        """Return the TensorFlow dtypes of ``(image, label)``."""
        return (tf.float32, tf.int32)
# Module-level synthetic data source. It feeds model.fit via get_data, and its
# output_shapes() also supplies the model's input_shape.
df = FakeData()
def get_data(df):
    """Wrap ``df`` in a batched, prefetching ``tf.data`` pipeline.

    Args:
        df: Source object exposing ``__iter__`` that yields (image, label)
            pairs, plus ``output_types()`` and ``output_shapes()`` methods
            describing those elements.

    Returns:
        A ``tf.data.Dataset`` of batches of (up to) ``BATCH`` elements, with
        prefetch depth chosen by ``tf.data.experimental.AUTOTUNE``.
    """
    element_types = df.output_types()
    element_shapes = df.output_shapes()
    source = tf.data.Dataset.from_generator(
        generator=df.__iter__,
        output_types=element_types,
        output_shapes=element_shapes)
    return source.batch(BATCH).prefetch(tf.data.experimental.AUTOTUNE)
# see https://www.tensorflow.org/guide/distributed_training#mirroredstrategy
# Model and optimizer variables are created inside the strategy scope so they
# are mirrored across devices; fit() is called outside the scope, as in the guide.
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    # include_top=True with Keras' default classes=1000 gives a 1000-way
    # classification head; weights=None trains from random initialization.
    model = tf.keras.applications.resnet.ResNet50(
        input_shape=df.output_shapes()[0],
        include_top=True,
        weights=None)
    # FakeData yields integer class indices of shape (1,), not one-hot
    # vectors, so the matching loss is sparse categorical cross-entropy.
    # binary_crossentropy expects per-output 0/1 targets shaped like the
    # predictions, which these labels are not.
    # With this loss, Keras resolves 'accuracy' to sparse categorical accuracy.
    # `learning_rate` replaces the deprecated `lr` alias (valid since TF 2.0).
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
model.fit(get_data(df), epochs=NUM_EPOCHS)
@luqiang21
Copy link

Line 18 does not generate a one-hot label. I was able to run the script after changing the label to a one-hot vector:

zeros = np.zeros(1000)
zeros[np.random.randint(1000)] = 1
self.Y_train = zeros.astype('int32')

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment