Skip to content

Instantly share code, notes, and snippets.

@simoncozens
Last active July 5, 2020 10:52
Show Gist options
  • Save simoncozens/661155d14b243e94d638797654a43cbb to your computer and use it in GitHub Desktop.
Save simoncozens/661155d14b243e94d638797654a43cbb to your computer and use it in GitHub Desktop.
from tensorfont.dataset import prepare_training_data
from tensorfont.generators import RandomPair
# One-time project setup from tensorfont. What it prepares is not visible in
# this file; presumably it must run before RandomPair(..., "training") can
# read training data -- TODO confirm.
# NOTE(review): the code below uses ``tf`` everywhere, but this file never
# imports it; ``import tensorflow as tf`` needs to be added to the imports.
prepare_training_data()
def kern_generator():
    """Yield (image pair, label) training examples for the kerning model.

    For every (font, left, right) triple produced by ``rpg.generator()``, two
    images are rendered with the same ``font``:

    * a candidate image of ``left``/``right`` with a kerning perturbation
      requested in the range ``[-100, 100]``;
    * a reference image of the pair "H", "H".

    Yields:
        ((candidate_rgb, reference_rgb), label) where the images are the
        ``.rgb()`` of each rendered image and ``label`` is True when
        ``abs(perturbation) < 50 * font.scale_factor``.

    The generator ends when ``rpg.generator()`` is exhausted. Before this
    change, exhaustion would have surfaced as a RuntimeError (PEP 479).
    """
    # NOTE(review): the meaning of the positional arguments (196, 196, 76) is
    # defined by tensorfont.RandomPair; presumably image size and a font
    # metric -- confirm against that class.
    rpg = RandomPair(196, 196, 76, "training")
    for font, left, right in rpg.generator():
        img1, perturbation = rpg.get_image(
            font, left, right, perturbation_range=[-100, 100]
        )
        # The reference image must come from the same font as the candidate,
        # so both calls use ``font``. This also means the threshold below is
        # scaled by that font's own scale_factor.
        img2, _ = rpg.get_image(font, "H", "H")
        label = abs(perturbation) < 50 * font.scale_factor
        yield (img1.rgb(), img2.rgb()), label
def mobilenet(prefix):
    """Build one headless MobileNetV2 branch with prefixed layer names.

    The backbone is created without its classification head and without
    pretrained weights, for 196x196 RGB inputs. Every layer is then renamed
    to ``prefix + "_" + <original name>``. Presumably this lets two branches
    built here coexist in one functional model without name clashes.

    Args:
        prefix: text prepended, with an underscore, to every layer name.

    Returns:
        A ``(input_tensor, output_tensor)`` tuple of the backbone.
    """
    backbone = tf.keras.applications.MobileNetV2(
        include_top=False,
        input_shape=(196, 196, 3),
        weights=None,
    )
    for layer in backbone.layers:
        # Writes the private ``_name`` attribute directly. Presumably this is
        # because Keras layers expose no public setter for ``name``.
        layer._name = prefix + "_" + layer.name
    return backbone.input, backbone.output
# Training pipeline: wrap kern_generator in a tf.data.Dataset, then batch it
# 32 examples at a time. Each element is ((candidate_img, reference_img), label)
# with both images declared as (196, 196, 3) and the label as a scalar.
# NOTE(review): images are declared int32 -- confirm that ``.rgb()`` yields
# integer pixel values compatible with that dtype.
g_t_kern = tf.data.Dataset.from_generator(
    kern_generator,
    output_types=((tf.int32, tf.int32), tf.int32),
    output_shapes=(((196, 196, 3), (196, 196, 3)), ()),
)
g_t_kern = g_t_kern.batch(32)
# Two-branch classifier. Each input image goes through its own independently
# initialised MobileNetV2 backbone. The two feature maps are concatenated and
# fed through a small MLP with a sigmoid output, trained against the boolean
# label produced by kern_generator.
input1, left = mobilenet("left")
input2, right = mobilenet("right")
merged = tf.keras.layers.Concatenate()([left, right])
# NOTE(review): Flattening the full concatenated feature maps may make the
# first Dense layer very large -- check kern_model.summary(); a
# GlobalAveragePooling2D could be a lighter alternative.
x = tf.keras.layers.Flatten()(merged)
# SELU only self-normalises when weights are initialised with lecun_normal, as
# the Keras documentation states. The Dense default, glorot_uniform, breaks
# that property.
x = tf.keras.layers.Dense(
    1024, activation="selu", kernel_initializer="lecun_normal"
)(x)
x = tf.keras.layers.Dense(
    256, activation="selu", kernel_initializer="lecun_normal"
)(x)
output = tf.keras.layers.Dense(1, activation="sigmoid")(x)
kern_model = tf.keras.models.Model(inputs=[input1, input2], outputs=[output])
kern_model.compile(
    optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
)
kern_model.fit(g_t_kern, epochs=500, steps_per_epoch=1000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment