@wenfahu
Last active September 4, 2017 02:09

Overall process

This toolset provides channel-level pruning of the Inception-ResNet v2 model. For details of the Inception-ResNet v2 model, see the paper: https://arxiv.org/abs/1602.07261

  1. python inf.py [meta] [ckpt] [output_mask] --threshold [threshold]: get the indices (mask) of the convolutional channel weights that fall below the threshold.
  2. python freeze_graph.py [model_dir] [output_file]: freeze the model weights.
  3. Simplify the TensorFlow graph using the Graph Transform Tool (GTT): https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md#optimizing-for-deployment
  4. python conv_travser.py [graph_path] [mask_path] [output_graph] [output_mask]: prune the model according to the layer dependencies of Inception-ResNet v2. output_graph is the pruned model, and output_mask is used for further training.
  5. python zero_ckpt.py [meta] [ckpt] [zidx] [output]: zero out the pruned weights in the checkpoint using the mask generated above.
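
The arithmetic behind steps 1, 4 and 5 can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code of inf.py, conv_travser.py or zero_ckpt.py: it assumes a channel is scored by the L2 norm of its weights, and that kernels are nested lists in TensorFlow's HWIO layout (kernel[h][w][in_channel][out_channel]). All function names are hypothetical.

```python
import math

# Assumption: kernels are nested lists in TensorFlow's HWIO layout,
# kernel[h][w][in_channel][out_channel]. The real scripts operate on
# TensorFlow checkpoints and graphs; this only shows the arithmetic.


def channel_norms(kernel):
    """L2 norm of each output channel of an HWIO conv kernel."""
    n_out = len(kernel[0][0][0])
    sq_sums = [0.0] * n_out
    for row in kernel:
        for col in row:
            for in_vec in col:
                for o, v in enumerate(in_vec):
                    sq_sums[o] += v * v
    return [math.sqrt(s) for s in sq_sums]


def channel_mask(kernel, threshold):
    """Indices of output channels whose L2 norm is below `threshold` (cf. step 1)."""
    return [o for o, n in enumerate(channel_norms(kernel)) if n < threshold]


def zero_channels(kernel, pruned):
    """Copy of `kernel` with the pruned output channels set to 0 (cf. step 5)."""
    pruned = set(pruned)
    return [[[[0.0 if o in pruned else v for o, v in enumerate(in_vec)]
              for in_vec in col]
             for col in row]
            for row in kernel]


def drop_input_channels(kernel, pruned):
    """Remove the given input channels from a consumer kernel (cf. step 4).

    A layer that reads the pruned layer's output loses the same channel indices
    along its input dimension.
    """
    pruned = set(pruned)
    return [[[in_vec for i, in_vec in enumerate(col) if i not in pruned]
             for col in row]
            for row in kernel]


if __name__ == "__main__":
    # A 1x1 kernel with 2 input channels and 3 output channels.
    k = [[[[1.0, 0.0, 0.01],
           [0.0, 0.0, 0.01]]]]
    mask = channel_mask(k, threshold=0.1)
    print(mask)                    # [1, 2]
    print(zero_channels(k, mask))  # [[[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]]
    consumer = [[[[0.5], [0.2], [0.3]]]]  # 1x1 kernel, 3 inputs, 1 output
    print(drop_input_channels(consumer, mask))  # [[[[0.5]]]]
```

In the real network the dependency handling is harder than drop_input_channels suggests: in Inception-ResNet v2, channels that are summed by a residual connection must share the same mask, and channels concatenated in an Inception block shift their indices. This is presumably what step 4 resolves from the graph.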

(Re)Training

usage: train_pruned_classifer.py [-h] [--logs_base_dir LOGS_BASE_DIR]
    [--models_base_dir MODELS_BASE_DIR] [--gpu_memory_fraction GPU_MEMORY_FRACTION] [--pretrained_model PRETRAINED_MODEL] [--data_dir DATA_DIR] [--model_def MODEL_DEF] [--max_nrof_epochs MAX_NROF_EPOCHS] [--batch_size BATCH_SIZE] [--image_size IMAGE_SIZE] [--epoch_size EPOCH_SIZE] [--embedding_size EMBEDDING_SIZE] [--random_crop] [--random_flip] [--random_rotate] [--keep_probability KEEP_PROBABILITY] [--weight_decay WEIGHT_DECAY] [--decov_loss_factor DECOV_LOSS_FACTOR] [--center_loss_factor CENTER_LOSS_FACTOR] [--center_loss_alfa CENTER_LOSS_ALFA] [--optimizer {ADAGRAD,ADADELTA,ADAM,RMSPROP,MOM}] [--learning_rate LEARNING_RATE] [--learning_rate_decay_epochs LEARNING_RATE_DECAY_EPOCHS] [--learning_rate_decay_factor LEARNING_RATE_DECAY_FACTOR] [--moving_average_decay MOVING_AVERAGE_DECAY] [--seed SEED] [--nrof_preprocess_threads NROF_PREPROCESS_THREADS] [--log_histograms] [--learning_rate_schedule_file LEARNING_RATE_SCHEDULE_FILE] [--filter_filename FILTER_FILENAME] [--filter_percentile FILTER_PERCENTILE] [--filter_min_nrof_images_per_class FILTER_MIN_NROF_IMAGES_PER_CLASS] [--no_store_revision_info] [--lfw_pairs LFW_PAIRS] [--lfw_file_ext {jpg,png}] [--lfw_dir LFW_DIR] [--lfw_batch_size LFW_BATCH_SIZE] [--lfw_nrof_folds LFW_NROF_FOLDS] [--finetune] [--pruning_mask PRUNING_MASK] [--meta_graph META_GRAPH] [--group_lasso_factor GROUP_LASSO_FACTOR]

optional arguments:

  -h, --help  show this help message and exit
  --logs_base_dir LOGS_BASE_DIR  Directory where to write event logs.
  --models_base_dir MODELS_BASE_DIR  Directory where to write trained models and checkpoints.
  --gpu_memory_fraction GPU_MEMORY_FRACTION  Upper bound on the amount of GPU memory that will be used by the process.
  --pretrained_model PRETRAINED_MODEL  Load a pretrained model before training starts.
  --data_dir DATA_DIR  Path to the data directory containing aligned face patches. Multiple directories are separated with a colon.
  --model_def MODEL_DEF  Model definition. Points to a module containing the definition of the inference graph.
  --max_nrof_epochs MAX_NROF_EPOCHS  Number of epochs to run.
  --batch_size BATCH_SIZE  Number of images to process in a batch.
  --image_size IMAGE_SIZE  Image size (height, width) in pixels.
  --epoch_size EPOCH_SIZE  Number of batches per epoch.
  --embedding_size EMBEDDING_SIZE  Dimensionality of the embedding.
  --random_crop  Performs random cropping of training images. If false, the center image_size pixels from the training images are used. If the size of the images in the data directory is equal to image_size, no cropping is performed.
  --random_flip  Performs random horizontal flipping of training images.
  --random_rotate  Performs random rotations of training images.
  --keep_probability KEEP_PROBABILITY  Keep probability of dropout for the fully connected layer(s).
  --weight_decay WEIGHT_DECAY  L2 weight regularization.
  --decov_loss_factor DECOV_LOSS_FACTOR  DeCov loss factor.
  --center_loss_factor CENTER_LOSS_FACTOR  Center loss factor.
  --center_loss_alfa CENTER_LOSS_ALFA  Center update rate for center loss.
  --optimizer {ADAGRAD,ADADELTA,ADAM,RMSPROP,MOM}  The optimization algorithm to use.
  --learning_rate LEARNING_RATE  Initial learning rate. If set to a negative value, a learning rate schedule can be specified in the file "learning_rate_schedule.txt".
  --learning_rate_decay_epochs LEARNING_RATE_DECAY_EPOCHS  Number of epochs between learning rate decays.
  --learning_rate_decay_factor LEARNING_RATE_DECAY_FACTOR  Learning rate decay factor.
  --moving_average_decay MOVING_AVERAGE_DECAY  Exponential decay for tracking of training parameters.
  --seed SEED  Random seed.
  --nrof_preprocess_threads NROF_PREPROCESS_THREADS  Number of preprocessing (data loading and augmentation) threads.
  --log_histograms  Enables logging of weight/bias histograms in TensorBoard.
  --learning_rate_schedule_file LEARNING_RATE_SCHEDULE_FILE  File containing the learning rate schedule that is used when learning_rate is set to -1.
  --filter_filename FILTER_FILENAME  File containing image data used for dataset filtering.
  --filter_percentile FILTER_PERCENTILE  Keep only the given percentile of images closest to their class center.
  --filter_min_nrof_images_per_class FILTER_MIN_NROF_IMAGES_PER_CLASS  Keep only the classes with this number of examples or more.
  --no_store_revision_info  Disables storing of git revision info in revision_info.txt.
  --lfw_pairs LFW_PAIRS  The file containing the pairs to use for validation.
  --lfw_file_ext {jpg,png}  The file extension for the LFW dataset.
  --lfw_dir LFW_DIR  Path to the data directory containing aligned face patches.
  --lfw_batch_size LFW_BATCH_SIZE  Number of images to process in a batch in the LFW test set.
  --lfw_nrof_folds LFW_NROF_FOLDS  Number of folds to use for cross validation. Mainly used for testing.
  --finetune  Fine-tune the model.
  --pruning_mask PRUNING_MASK  Pruning mask applied during backpropagation.
  --meta_graph META_GRAPH  Load a pretrained metagraph before training starts.
  --group_lasso_factor GROUP_LASSO_FACTOR  Scale factor for the group lasso regularization.
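
--learning_rate_decay_epochs and --learning_rate_decay_factor together describe a step-wise exponential decay. The sketch below is a minimal illustration, assuming the rate is multiplied by the decay factor once every learning_rate_decay_epochs epochs (a staircase schedule) and that a non-positive decay interval disables decay; the function name is hypothetical, and the schedule-file path taken for a negative --learning_rate is not modeled.

```python
def decayed_learning_rate(base_lr, epoch, decay_epochs, decay_factor):
    """Staircase decay: lr = base_lr * decay_factor ** (epoch // decay_epochs).

    Assumptions: decay happens in whole-epoch steps, and decay_epochs <= 0
    means "no decay".
    """
    if base_lr < 0:
        # A negative learning_rate selects the schedule file instead; not modeled here.
        raise ValueError("negative learning rate: use the schedule file")
    if decay_epochs <= 0:
        return base_lr
    return base_lr * decay_factor ** (epoch // decay_epochs)


if __name__ == "__main__":
    for epoch in (0, 99, 100, 250):
        print(epoch, decayed_learning_rate(0.1, epoch, decay_epochs=100, decay_factor=0.5))
```

With these values the rate stays at 0.1 for epochs 0 to 99, halves to 0.05 at epoch 100, and is 0.025 by epoch 250.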

The group_lasso_factor is generally set to 8e-5 for the Inception-ResNet v2 model. If pruning_mask is provided, the gradients of the parameters in the mask are zeroed out, which prevents the pruned parameters from being updated.
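
Both mechanisms can be illustrated with a minimal sketch. It assumes that each group is the flattened weights of one output channel (consistent with channel-level pruning, but not confirmed by this document), that the penalty has the standard group lasso form λ Σ_g ||w_g||_2 with λ = group_lasso_factor, and a plain SGD update rather than the optimizers listed above. All names and the gradient values in the example are hypothetical.

```python
import math

# Assumption: groups maps an output-channel index to that channel's flattened weights.


def group_lasso_penalty(groups, factor):
    """factor * sum over groups of the (unsquared) L2 norm of each group."""
    return factor * sum(math.sqrt(sum(w * w for w in ws)) for ws in groups.values())


def group_lasso_grad(groups, factor, eps=1e-12):
    """Subgradient of the penalty: factor * w_g / ||w_g||_2, and 0 for an all-zero group."""
    grads = {}
    for c, ws in groups.items():
        norm = math.sqrt(sum(w * w for w in ws))
        grads[c] = [factor * w / norm if norm > eps else 0.0 for w in ws]
    return grads


def mask_gradients(grads, pruned):
    """Zero the gradients of pruned channels so those weights never change."""
    return {c: [0.0] * len(g) if c in pruned else g for c, g in grads.items()}


def sgd_step(groups, grads, lr):
    """Plain gradient-descent update, w <- w - lr * g."""
    return {c: [w - lr * g for w, g in zip(ws, grads[c])] for c, ws in groups.items()}


if __name__ == "__main__":
    w = {0: [3.0, 4.0], 1: [0.0, 0.0]}  # channel 1 was zeroed by the pruning step
    print(group_lasso_penalty(w, 8e-5))  # about 4e-4 (8e-5 * ||[3, 4]||)
    loss_grads = {0: [0.1, 0.1], 1: [0.5, -0.5]}  # hypothetical task-loss gradients
    reg = group_lasso_grad(w, 8e-5)
    total = {c: [a + b for a, b in zip(loss_grads[c], reg[c])] for c in w}
    w = sgd_step(w, mask_gradients(total, pruned={1}), lr=0.1)
    print(w[1])  # [0.0, 0.0]: the pruned channel did not move
```

Because the penalty uses the unsquared norm of each group, it pushes whole channels toward zero rather than individual weights, which is what makes them fall under the threshold used in step 1.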

This toolset depends heavily on David Sandberg's facenet (https://github.com/davidsandberg/facenet). The model is trained on the MS-Celeb-1M dataset (https://www.microsoft.com/en-us/research/project/ms-celeb-1m-challenge-recognizing-one-million-celebrities-real-world/) and validated with the LFW face verification protocol (http://vis-www.cs.umass.edu/lfw/).
