Last active
May 12, 2020 06:26
-
-
Save jamesonthecrow/0ed0ab137eaec889d22dda237ee735b7 to your computer and use it in GitHub Desktop.
Train a Fritz Style Transfer model with a custom style image in 20 minutes. Create Core ML and TensorFlow Mobile versions for use in your app. More information at https://fritz.ai.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env bash
#
# A script to train an artistic style transfer model from a custom style
# image, then convert the trained model to Core ML and TensorFlow Mobile.
#
# A Google Colab going through the same steps can be found here:
# https://colab.research.google.com/drive/1nDkxLKBgZGFscGoF0tfyPMGqW03xITl0#scrollTo=V33xVH-CWUCs
#
# Note that this script will download and unzip 1GB of photos for training.
# Make sure you have the appropriate permissions to use any images.
#
# Usage:
#   Edit STYLE_IMAGE_URL below, or export it before running:
#     STYLE_IMAGE_URL='https://example.com/style.jpg' bash <this script>
#
# The script is safe to re-run: the repository clone, the dataset download
# and the TFRecord conversion are skipped when their outputs already exist.

# Stop at the first failing step so later steps never run against missing
# or partially downloaded inputs.
set -euo pipefail

# CHANGE ME BEFORE RUNNING (or export STYLE_IMAGE_URL in the environment).
STYLE_IMAGE_URL="${STYLE_IMAGE_URL:-STYLE_IMAGE_URL}"

# Fail fast if the placeholder was never replaced, before spending time on
# package installs and downloads.
if [[ "${STYLE_IMAGE_URL}" == 'STYLE_IMAGE_URL' ]]; then
  printf '%s\n' \
    'error: STYLE_IMAGE_URL is still the placeholder; set it to the URL of your style image.' >&2
  exit 1
fi

# Install requirements
pip install keras==2.2.4 tensorflow numpy matplotlib
pip install git+https://www.github.com/keras-team/keras-contrib.git
pip install git+https://github.com/apple/coremltools.git@master

# Clone the style transfer repository (skipped if a previous run already did)
if [[ ! -d fritz-models ]]; then
  git clone https://github.com/fritzlabs/fritz-models.git
fi
cd fritz-models/style_transfer

# Add this directory to your python path. The ${VAR:+...} form avoids a
# leading empty entry (and an unset-variable error) when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${PWD}"

# Create a data directory. This is ignored by git
mkdir -p data/

# Download the style image. The URL is quoted so characters such as
# '?' and '&' are passed to wget untouched.
wget -O data/style_image.jpg "${STYLE_IMAGE_URL}"

# Download and unzip the training images from MSCOCO if we haven't already
if [[ ! -d data/val2017 ]]; then
  wget -O data/val2017.zip http://images.cocodataset.org/zips/val2017.zip
  unzip -d data/ data/val2017.zip
fi

# Convert images to TFRecord format
if [[ ! -f data/training_images.tfrecord ]]; then
  python create_training_dataset.py \
    --output data/training_images.tfrecord \
    --image-dir data/val2017/
fi

# Train the model
python style_transfer/train.py \
  --training-image-dset data/training_images.tfrecord \
  --style-images data/style_image.jpg \
  --model-checkpoint data/my_style_025.h5 \
  --image-size 256,256 \
  --alpha 0.25 \
  --num-iterations 2 \
  --batch-size 2 \
  --fine-tune-checkpoint example/starry_night_256x256_025.h5

# Convert to Core ML
python convert_to_coreml.py \
  --keras-checkpoint data/my_style_025.h5 \
  --alpha 0.25 \
  --image-size 640,640 \
  --coreml-model data/my_style_025.mlmodel

# Convert to TensorFlow Mobile
python convert_to_tfmobile.py \
  --keras-checkpoint data/my_style_025.h5 \
  --alpha 0.25 \
  --image-size 640,640 \
  --output-dir data/
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment