@jamesonthecrow
Last active May 12, 2020 06:26
Train a Fritz Style Transfer model with a custom style image in 20 minutes. Create Core ML and TensorFlow Mobile versions for use in your app. More information at https://fritz.ai.
# A script to train an artistic style transfer model from a custom style image.
# A Google Colab notebook that walks through the same steps is available here:
# https://colab.research.google.com/drive/1nDkxLKBgZGFscGoF0tfyPMGqW03xITl0#scrollTo=V33xVH-CWUCs
# Note that this script will download and unzip 1GB of photos for training.
# Make sure you have the appropriate permissions to use any images.
# CHANGE ME BEFORE RUNNING: set this to the URL of your style image.
STYLE_IMAGE_URL='STYLE_IMAGE_URL'
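# Added guard (not part of the original gist): stop here with a clear message
# if the placeholder above was never replaced, instead of failing later inside
# wget. style_url_is_placeholder is a helper defined here, not from fritz-models.
style_url_is_placeholder() {
  [ "$1" = 'STYLE_IMAGE_URL' ]
}
if style_url_is_placeholder "${STYLE_IMAGE_URL:-}"; then
  echo "Set STYLE_IMAGE_URL to the URL of your style image before running." >&2
  exit 1
fi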
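# Install requirements.
# keras 2.2.4 predates TensorFlow 2.x; if Keras fails to import, install a
# TensorFlow 1.x release instead of the latest tensorflow.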
pip install keras==2.2.4 tensorflow numpy matplotlib
pip install git+https://www.github.com/keras-team/keras-contrib.git
pip install git+https://github.com/apple/coremltools.git@master
# Clone the style transfer repository
git clone https://github.com/fritzlabs/fritz-models.git
cd fritz-models/style_transfer
# Add this directory to your Python path
export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$(pwd)"
# Create a data directory (it is ignored by git). -p keeps reruns from failing.
mkdir -p data
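# Added (not part of the original gist): from here on, abort at the first
# failing command so a failed download, conversion, or training run does not
# cascade into confusing errors in later steps. It sits after the clone and
# mkdir steps so rerunning the script in an existing checkout still works.
set -e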
# Download the style image
wget -O data/style_image.jpg "$STYLE_IMAGE_URL"
# Download and unzip the training images from MSCOCO if we haven't already
if [ ! -d data/val2017 ]; then
  wget -O data/val2017.zip http://images.cocodataset.org/zips/val2017.zip
  unzip -d data/ data/val2017.zip
fi
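# Added sanity check (not part of the original gist): COCO val2017 contains
# 5,000 JPEG images, so a different count usually means an interrupted
# download or unzip. count_jpegs is a helper defined here for illustration.
count_jpegs() {
  find "$1" -type f -name '*.jpg' | wc -l | tr -d ' '
}
num_images=$(count_jpegs data/val2017)
if [ "$num_images" -ne 5000 ]; then
  echo "Warning: expected 5000 images in data/val2017, found ${num_images}; delete data/val2017 and rerun to re-download." >&2
fi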
# Convert images to TFRecord format
if [ ! -f data/training_images.tfrecord ]; then
  python create_training_dataset.py \
    --output data/training_images.tfrecord \
    --image-dir data/val2017/
fi
# Train the model
python style_transfer/train.py \
  --training-image-dset data/training_images.tfrecord \
  --style-images data/style_image.jpg \
  --model-checkpoint data/my_style_025.h5 \
  --image-size 256,256 \
  --alpha 0.25 \
  --num-iterations 2 \
  --batch-size 2 \
  --fine-tune-checkpoint example/starry_night_256x256_025.h5
# Convert to Core ML (--alpha must match the value used for training)
python convert_to_coreml.py \
  --keras-checkpoint data/my_style_025.h5 \
  --alpha 0.25 \
  --image-size 640,640 \
  --coreml-model data/my_style_025.mlmodel
# Convert to TensorFlow Mobile
python convert_to_tfmobile.py \
  --keras-checkpoint data/my_style_025.h5 \
  --alpha 0.25 \
  --image-size 640,640 \
  --output-dir data/
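# Added summary (not part of the original gist): report which expected outputs
# exist and are non-empty. The TensorFlow Mobile file written to data/ is not
# listed because the script does not name it. report_outputs is a helper
# defined here for illustration.
report_outputs() {
  for f in "$@"; do
    if [ -s "$f" ]; then
      echo "ok      $f"
    else
      echo "MISSING $f"
    fi
  done
}
report_outputs data/my_style_025.h5 data/my_style_025.mlmodel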