Skip to content

Instantly share code, notes, and snippets.

@aoikonomop
Created December 15, 2017 18:00
Show Gist options
  • Save aoikonomop/20b374c680d45e8d3c0b39f474779514 to your computer and use it in GitHub Desktop.
Save aoikonomop/20b374c680d45e8d3c0b39f474779514 to your computer and use it in GitHub Desktop.
cascaded CNN training
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import logging\n",
"\n",
"from hudl_beatrix.models import CascadedCNNModelConfig, cascadedcnn_model_factory\n",
"from hudl_beatrix.optimizers import RmsPropOptimizerConfig, AdamOptimizerConfig, LearningRateConfig\n",
"from hudl_beatrix.extractors import AlexNetLiteExtractorConfig, BranchExtractorConfig\n",
"\n",
"from hudl_beatrix.datasets.patch import PatchDataset\n",
"\n",
"from hudl_beatrix.datasets.downloaded import WatsonDatasetConfig\n",
"from hudl_beatrix.datasets.patch import PatchDatasetConfig, patch_dataset_factory\n",
"\n",
"logging.info('Test')\n",
"logger = logging.getLogger()\n",
"logger.setLevel(logging.INFO)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cascaded_cnn_config = CascadedCNNModelConfig(\n",
" base_extractor=AlexNetLiteExtractorConfig(),\n",
" branch_extractors=[\n",
" BranchExtractorConfig(branch_id=1, pool=True),\n",
"# BranchExtractorConfig(branch_id=2, pool=True),\n",
"# BranchExtractorConfig(branch_id=3, pool=False),\n",
"# BranchExtractorConfig(branch_id=4, pool=False)\n",
" ],\n",
" optimizer=AdamOptimizerConfig(learning_rate=LearningRateConfig(learning_rate=0.001))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = cascadedcnn_model_factory(cascaded_cnn_config)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator = tf.estimator.Estimator(model.__call__, model_dir='./train_cascade_dir/')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_config = [WatsonDatasetConfig(root='patch01', image_method='imageio')]\n",
"patch_dataset_config = PatchDatasetConfig(datasets=dataset_config, \n",
" name='patch01', image_method='imageio', overwrite=False,\n",
" negative_ratio=1, batch_size=32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = patch_dataset_factory(patch_dataset_config)\n",
"dataset.extract_patches()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_fn_train = dataset.train_input_fn\n",
"input_fn_dev = dataset.dev_input_fn"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"for epoch in range(2000):\n",
" estimator.train(input_fn=input_fn_train, steps = None)\n",
" estimator.evaluate(input_fn=input_fn_dev)\n",
" print (epoch)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.