Skip to content

Instantly share code, notes, and snippets.

@sjchoi86
Created April 15, 2018 16:31
Show Gist options
  • Save sjchoi86/19cf18353718e3ea26a5355a7a50856d to your computer and use it in GitHub Desktop.
Save sjchoi86/19cf18353718e3ea26a5355a7a50856d to your computer and use it in GitHub Desktop.
mcdn/code/main_cifar10_config.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "# CIFAR-10 configurations"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import os,nbloader,warnings,sys\nwarnings.filterwarnings(\"ignore\")\nimport numpy as np\nimport tensorflow as tf \nimport tensorflow.contrib.slim as slim\nfrom demo_cnn_cls import cnn_cls_class\nfrom demo_mcdn_cls import mcdn_cls_class\nfrom demo_util import gpusession,load_cifar_with_noise,grid_maker\nif __name__ == \"__main__\":\n print (\"Python version is [%s]\"%(sys.version_info[0]))",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "Python version is [3]\n"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Common configuration & CNN and MCDN configurations"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def get_cifar10_common_config():\n    \"\"\"Return the settings shared by every CIFAR-10 run as one flat tuple.\n\n    Returned order:\n        xdim, ydim, filterSizes, max_pools,\n        maxEpoch, batchSize, l2_reg_coef,\n        USE_RESNET, USE_SGD, PRINT_EVERY, VERBOSE,\n        DO_AUGMENTATION\n    \"\"\"\n    xdim, ydim = [32, 32, 3], 10\n    # filterSizes and max_pools both carry 10 entries\n    filterSizes = [3] * 10\n    max_pools = [1, 1, 1, 2, 1, 1, 2, 1, 1, 2]\n    maxEpoch, batchSize = 200, 256\n    l2_reg_coef = 1e-6\n    USE_RESNET = USE_SGD = True\n    PRINT_EVERY = maxEpoch  # tied to the epoch count\n    VERBOSE = False\n    DO_AUGMENTATION = True\n    return (\n        xdim, ydim, filterSizes, max_pools,\n        maxEpoch, batchSize, l2_reg_coef,\n        USE_RESNET, USE_SGD, PRINT_EVERY, VERBOSE,\n        DO_AUGMENTATION,\n    )\n\n\ndef get_cifar10_mcdn_config():\n    \"\"\"Return the MCDN-only settings as one flat tuple.\n\n    Returned order:\n        actv, bn, rho_ref_train, tau_inv, pi1_bias,\n        logSigmaZval, logsumexp_coef, kl_reg_coef, USE_KENDALL_LOSS\n    \"\"\"\n    actv, bn = tf.nn.relu, slim.batch_norm\n    rho_ref_train = 0.95\n    tau_inv = 1e-4\n    pi1_bias = 0.0\n    logSigmaZval = -3\n    logsumexp_coef = kl_reg_coef = 1e-4\n    USE_KENDALL_LOSS = False\n    return (\n        actv, bn, rho_ref_train, tau_inv, pi1_bias,\n        logSigmaZval, logsumexp_coef, kl_reg_coef, USE_KENDALL_LOSS,\n    )\n\n\ndef get_cifar10_cnn_config():\n    \"\"\"Return the baseline-CNN settings as the pair (actv, bn).\"\"\"\n    return tf.nn.relu, slim.batch_norm",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"scrolled": true,
"trusted": true
},
"cell_type": "code",
"source": "def get_properIdx(_processID,_maxProcessID,_nTask):\n    \"\"\"Return the task indices assigned to one worker process.\n\n    Tasks 0.._nTask-1 are dealt out round-robin: process p receives\n    p, p+_maxProcessID, p+2*_maxProcessID, ... for as long as the index\n    stays below _nTask.\n\n    Args:\n        _processID: zero-based worker index; valid values are 0.._maxProcessID-1.\n        _maxProcessID: total number of worker processes (the stride).\n        _nTask: total number of tasks.\n\n    Returns:\n        Ascending list of task indices. Empty when _processID lies outside\n        0.._maxProcessID-1 (which includes every ID when _maxProcessID <= 0)\n        or when no task is left for this process.\n    \"\"\"\n    # IDs are zero-based, so _processID == _maxProcessID is already out of range;\n    # letting it through would hand out process 0's tasks a second time.\n    if not 0 <= _processID < _maxProcessID:\n        return []\n    return list(range(_processID, _nTask, _maxProcessID))\n\nif __name__ == \"__main__\":\n    # Fewer tasks than workers, slightly more, and several rounds of tasks\n    for maxProcessID,nTask in [(8,4),(8,11),(8,30)]:\n        print (\"\\nmaxProcessID:[%d], nTask:[%d]\"%(maxProcessID,nTask ))\n        for processID in range(maxProcessID):\n            ids = get_properIdx(_processID=processID,_maxProcessID=maxProcessID,_nTask=nTask)\n            print (\" processID:[%d] %s\"%(processID,ids))",
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "\nmaxProcessID:[8], nTask:[4]\n processID:[0] [0]\n processID:[1] [1]\n processID:[2] [2]\n processID:[3] [3]\n processID:[4] []\n processID:[5] []\n processID:[6] []\n processID:[7] []\n\nmaxProcessID:[8], nTask:[11]\n processID:[0] [0, 8]\n processID:[1] [1, 9]\n processID:[2] [2, 10]\n processID:[3] [3]\n processID:[4] [4]\n processID:[5] [5]\n processID:[6] [6]\n processID:[7] [7]\n\nmaxProcessID:[8], nTask:[30]\n processID:[0] [0, 8, 16, 24]\n processID:[1] [1, 9, 17, 25]\n processID:[2] [2, 10, 18, 26]\n processID:[3] [3, 11, 19, 27]\n processID:[4] [4, 12, 20, 28]\n processID:[5] [5, 13, 21, 29]\n processID:[6] [6, 14, 22]\n processID:[7] [7, 15, 23]\n"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### GPU-specific configuration"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Each list below is one axis of the grid handed to grid_maker.\n\n# Train with MCDN (True) or not (False)\nTEST_MCDN_list   = [True, False]\n# Use GAP (True) or FCN (False)\nUSE_GAP_list     = [True, False]    # <== Tuning parameter\n# Number of K\nkmixList         = [5, 10, 20]\n# Base channel size\nbaseChannelList  = [32, 64, 128]    # <== Tuning parameter\n# Base learning rate & momentum\nlrBaseList       = [1e-2, 1e-1]\nmomentumList     = [0.5]\n# Error type and outlier ratio\nerrTypeList      = ['rp', 'rs']\noutlierRatioList = [0.5, 0.8]\n# Feature dimension\nfeatDimList      = [128, 256]\n\nif __name__ == \"__main__\":\n    G = grid_maker(\n        USE_GAP_list, kmixList, baseChannelList, lrBaseList,\n        momentumList, errTypeList, outlierRatioList, featDimList, TEST_MCDN_list,\n    )\n    print (\"G.nIter:[%d]\"%(G.nIter))",
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "G.nIter:[576]\n"
}
]
},
{
"metadata": {
"scrolled": true,
"trusted": true
},
"cell_type": "code",
"source": "def get_cifgar10_config(_processID=0,_maxProcessID=8):\n    \"\"\"Return the parameter tuples and GPU index for one worker process.\n\n    Builds the full grid with grid_maker over the module-level *List\n    variables, then keeps the entries whose indices get_properIdx assigns\n    to this process.\n\n    Args:\n        _processID: zero-based worker index.\n        _maxProcessID: total number of worker processes.\n\n    Returns:\n        (_paramsList, _GPU_ID): the parameter tuples for this process and the\n        GPU index, which is _processID modulo 8.\n    \"\"\"\n    N_GPU = 8 # process IDs wrap around GPU IDs 0..7 (8 -> 0, 9 -> 1, ...)\n    # Get total configurations\n    _G = grid_maker(USE_GAP_list,kmixList,baseChannelList,lrBaseList\n                    ,momentumList,errTypeList,outlierRatioList,featDimList,TEST_MCDN_list)\n    # Get current configurations\n    _ids = get_properIdx(_processID,_maxProcessID,_nTask=_G.nIter)\n    _paramsList = [_G.paramList[i] for i in _ids]\n    # Set GPU ID: same mapping as the former if/elif chain for IDs 0..15, but\n    # defined for every ID (the chain left _GPU_ID unbound outside 0..15).\n    _GPU_ID = _processID % N_GPU\n    # Return\n    return _paramsList,_GPU_ID\n\nif __name__ == \"__main__\":\n    processID = 0\n    maxProcessID = 8\n    paramsList,GPU_ID = get_cifgar10_config(processID,maxProcessID)\n    print (\"GPU_ID:[%d]\"%(GPU_ID))\n    for pIdx,params in enumerate(paramsList): # For all current configurations\n        print (pIdx,params)",
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "GPU_ID:[0]\n0 (True, 5, 32, 0.01, 0.5, 'rp', 0.5, 128, True)\n1 (True, 5, 32, 0.01, 0.5, 'rs', 0.5, 128, True)\n2 (True, 5, 32, 0.1, 0.5, 'rp', 0.5, 128, True)\n3 (True, 5, 32, 0.1, 0.5, 'rs', 0.5, 128, True)\n4 (True, 5, 64, 0.01, 0.5, 'rp', 0.5, 128, True)\n5 (True, 5, 64, 0.01, 0.5, 'rs', 0.5, 128, True)\n6 (True, 5, 64, 0.1, 0.5, 'rp', 0.5, 128, True)\n7 (True, 5, 64, 0.1, 0.5, 'rs', 0.5, 128, True)\n8 (True, 5, 128, 0.01, 0.5, 'rp', 0.5, 128, True)\n9 (True, 5, 128, 0.01, 0.5, 'rs', 0.5, 128, True)\n10 (True, 5, 128, 0.1, 0.5, 'rp', 0.5, 128, True)\n11 (True, 5, 128, 0.1, 0.5, 'rs', 0.5, 128, True)\n12 (True, 10, 32, 0.01, 0.5, 'rp', 0.5, 128, True)\n13 (True, 10, 32, 0.01, 0.5, 'rs', 0.5, 128, True)\n14 (True, 10, 32, 0.1, 0.5, 'rp', 0.5, 128, True)\n15 (True, 10, 32, 0.1, 0.5, 'rs', 0.5, 128, True)\n16 (True, 10, 64, 0.01, 0.5, 'rp', 0.5, 128, True)\n17 (True, 10, 64, 0.01, 0.5, 'rs', 0.5, 128, True)\n18 (True, 10, 64, 0.1, 0.5, 'rp', 0.5, 128, True)\n19 (True, 10, 64, 0.1, 0.5, 'rs', 0.5, 128, True)\n20 (True, 10, 128, 0.01, 0.5, 'rp', 0.5, 128, True)\n21 (True, 10, 128, 0.01, 0.5, 'rs', 0.5, 128, True)\n22 (True, 10, 128, 0.1, 0.5, 'rp', 0.5, 128, True)\n23 (True, 10, 128, 0.1, 0.5, 'rs', 0.5, 128, True)\n24 (True, 20, 32, 0.01, 0.5, 'rp', 0.5, 128, True)\n25 (True, 20, 32, 0.01, 0.5, 'rs', 0.5, 128, True)\n26 (True, 20, 32, 0.1, 0.5, 'rp', 0.5, 128, True)\n27 (True, 20, 32, 0.1, 0.5, 'rs', 0.5, 128, True)\n28 (True, 20, 64, 0.01, 0.5, 'rp', 0.5, 128, True)\n29 (True, 20, 64, 0.01, 0.5, 'rs', 0.5, 128, True)\n30 (True, 20, 64, 0.1, 0.5, 'rp', 0.5, 128, True)\n31 (True, 20, 64, 0.1, 0.5, 'rs', 0.5, 128, True)\n32 (True, 20, 128, 0.01, 0.5, 'rp', 0.5, 128, True)\n33 (True, 20, 128, 0.01, 0.5, 'rs', 0.5, 128, True)\n34 (True, 20, 128, 0.1, 0.5, 'rp', 0.5, 128, True)\n35 (True, 20, 128, 0.1, 0.5, 'rs', 0.5, 128, True)\n36 (False, 5, 32, 0.01, 0.5, 'rp', 0.5, 128, True)\n37 (False, 5, 32, 0.01, 0.5, 'rs', 0.5, 128, True)\n38 (False, 5, 32, 0.1, 0.5, 
'rp', 0.5, 128, True)\n39 (False, 5, 32, 0.1, 0.5, 'rs', 0.5, 128, True)\n40 (False, 5, 64, 0.01, 0.5, 'rp', 0.5, 128, True)\n41 (False, 5, 64, 0.01, 0.5, 'rs', 0.5, 128, True)\n42 (False, 5, 64, 0.1, 0.5, 'rp', 0.5, 128, True)\n43 (False, 5, 64, 0.1, 0.5, 'rs', 0.5, 128, True)\n44 (False, 5, 128, 0.01, 0.5, 'rp', 0.5, 128, True)\n45 (False, 5, 128, 0.01, 0.5, 'rs', 0.5, 128, True)\n46 (False, 5, 128, 0.1, 0.5, 'rp', 0.5, 128, True)\n47 (False, 5, 128, 0.1, 0.5, 'rs', 0.5, 128, True)\n48 (False, 10, 32, 0.01, 0.5, 'rp', 0.5, 128, True)\n49 (False, 10, 32, 0.01, 0.5, 'rs', 0.5, 128, True)\n50 (False, 10, 32, 0.1, 0.5, 'rp', 0.5, 128, True)\n51 (False, 10, 32, 0.1, 0.5, 'rs', 0.5, 128, True)\n52 (False, 10, 64, 0.01, 0.5, 'rp', 0.5, 128, True)\n53 (False, 10, 64, 0.01, 0.5, 'rs', 0.5, 128, True)\n54 (False, 10, 64, 0.1, 0.5, 'rp', 0.5, 128, True)\n55 (False, 10, 64, 0.1, 0.5, 'rs', 0.5, 128, True)\n56 (False, 10, 128, 0.01, 0.5, 'rp', 0.5, 128, True)\n57 (False, 10, 128, 0.01, 0.5, 'rs', 0.5, 128, True)\n58 (False, 10, 128, 0.1, 0.5, 'rp', 0.5, 128, True)\n59 (False, 10, 128, 0.1, 0.5, 'rs', 0.5, 128, True)\n60 (False, 20, 32, 0.01, 0.5, 'rp', 0.5, 128, True)\n61 (False, 20, 32, 0.01, 0.5, 'rs', 0.5, 128, True)\n62 (False, 20, 32, 0.1, 0.5, 'rp', 0.5, 128, True)\n63 (False, 20, 32, 0.1, 0.5, 'rs', 0.5, 128, True)\n64 (False, 20, 64, 0.01, 0.5, 'rp', 0.5, 128, True)\n65 (False, 20, 64, 0.01, 0.5, 'rs', 0.5, 128, True)\n66 (False, 20, 64, 0.1, 0.5, 'rp', 0.5, 128, True)\n67 (False, 20, 64, 0.1, 0.5, 'rs', 0.5, 128, True)\n68 (False, 20, 128, 0.01, 0.5, 'rp', 0.5, 128, True)\n69 (False, 20, 128, 0.01, 0.5, 'rs', 0.5, 128, True)\n70 (False, 20, 128, 0.1, 0.5, 'rp', 0.5, 128, True)\n71 (False, 20, 128, 0.1, 0.5, 'rs', 0.5, 128, True)\n"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### CIFAR-10 Train wrapper"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def train_wrapper_cifar10(_paramsList,_GPU_ID):\n    \"\"\"Train one model per configuration in _paramsList, one after another.\n\n    Args:\n        _paramsList: sequence of configurations; the first nine items of each\n            entry are read as (USE_GAP, kmix, base_channel, lr_base, momentum,\n            errType, outlierRatio, feat_dim, TEST_MCDN).\n        _GPU_ID: forwarded unchanged to the model constructor as _GPU_ID.\n\n    For every configuration the data is loaded with load_cifar_with_noise\n    (_seed=0), the default graph is reset, the TF and NumPy seeds are set to 0,\n    and either mcdn_cls_class (TEST_MCDN true) or cnn_cls_class is built and\n    trained. The session is closed even if training raises.\n    \"\"\"\n    for params in _paramsList: # For all current configurations\n\n        # Parse current configuration\n        (USE_GAP,kmix,base_channel,lr_base,momentum,\n         errType,outlierRatio,feat_dim,TEST_MCDN) = params[:9]\n        _bs = base_channel\n        hdims = [_bs]*4 + [_bs*2]*3 + [_bs*4]*3\n        gap_tag = 'GAP' if USE_GAP else 'FCN'\n\n        # Load CIFAR-10 with outlier\n        trainimg,trainlabel,testimg,testlabel,valimg,vallabel = load_cifar_with_noise(\n            _errType=errType,_outlierRatio=outlierRatio,_seed=0)\n\n        # Common configurations (re-read per run so each model gets fresh list objects)\n        (xdim,ydim,filterSizes,max_pools,\n         maxEpoch,batchSize,l2_reg_coef,\n         USE_RESNET,USE_SGD,PRINT_EVERY,VERBOSE,\n         DO_AUGMENTATION) = get_cifar10_common_config()\n\n        if TEST_MCDN:\n            # Proposed MCDN\n            (actv,bn,rho_ref_train,tau_inv,pi1_bias,\n             logSigmaZval,logsumexp_coef,kl_reg_coef,USE_KENDALL_LOSS) = get_cifar10_mcdn_config()\n            tf.reset_default_graph(); tf.set_random_seed(0); np.random.seed(0)\n            model = mcdn_cls_class(\n                _name=('cifar10_%s_err%.0f_mcdn_ch%d_k%d_fdim%d_bs%d_lr%.2e_mmt%.2e_%s'\n                       %(errType,outlierRatio*100,base_channel,kmix,feat_dim,batchSize,lr_base,momentum,gap_tag))\n                ,_xdim=xdim,_ydim=ydim,_hdims=hdims,_filterSizes=filterSizes\n                ,_max_pools=max_pools,_feat_dim=feat_dim\n                ,_kmix=kmix,_actv=actv,_bn=bn # bn from the MCDN config (was hard-coded slim.batch_norm)\n                ,_rho_ref_train=rho_ref_train,_tau_inv=tau_inv,_pi1_bias=pi1_bias,_logSigmaZval=logSigmaZval\n                ,_logsumexp_coef=logsumexp_coef,_kl_reg_coef=kl_reg_coef,_l2_reg_coef=l2_reg_coef\n                ,_USE_RESNET=USE_RESNET,_USE_GAP=USE_GAP,_USE_KENDALL_LOSS=USE_KENDALL_LOSS,_USE_SGD=USE_SGD\n                ,_GPU_ID=_GPU_ID,_VERBOSE=VERBOSE)\n            lr = lr_base\n        else:\n            # Baseline CNN\n            # NOTE(review): kmix is used neither here nor in the CNN _name, so\n            # configurations that differ only in kmix train the same baseline\n            # under the same name.\n            actv,bn = get_cifar10_cnn_config()\n            tf.reset_default_graph(); tf.set_random_seed(0); np.random.seed(0)\n            model = cnn_cls_class(\n                _name=('cifar10_%s_err%.0f_cnn_ch%d_fdim%d_bs%d_lr%.2e_mmt%.2e_%s'\n                       %(errType,outlierRatio*100,base_channel,feat_dim,batchSize,lr_base,momentum,gap_tag))\n                ,_xdim=xdim,_ydim=ydim,_hdims=hdims,_filterSizes=filterSizes\n                ,_max_pools=max_pools,_feat_dim=feat_dim\n                ,_actv=actv,_bn=bn,_l2_reg_coef=l2_reg_coef\n                ,_USE_RESNET=USE_RESNET,_USE_GAP=USE_GAP,_USE_SGD=USE_SGD\n                ,_GPU_ID=_GPU_ID,_VERBOSE=VERBOSE)\n            # NOTE(review): the baseline trains at lr_base*0.1 while its _name\n            # records lr_base - confirm this is intentional.\n            lr = lr_base*0.1\n\n        sess = gpusession()\n        try:\n            sess.run(tf.global_variables_initializer())\n            model.train(_sess=sess,_trainimg=trainimg,_trainlabel=trainlabel\n                        ,_testimg=testimg,_testlabel=testlabel,_valimg=valimg,_vallabel=vallabel\n                        ,_maxEpoch=maxEpoch,_batchSize=batchSize,_lr=lr\n                        ,_LR_SCHEDULE=True,_PRINT_EVERY=PRINT_EVERY,_SAVE_BEST=True\n                        ,_DO_AUGMENTATION=DO_AUGMENTATION)\n        finally:\n            sess.close() # close the session even when training fails\n\nif __name__ == \"__main__\":\n    processID = 0\n    maxProcessID = 8\n    paramsList,GPU_ID = get_cifgar10_config(processID,maxProcessID)\n    # train_wrapper_cifar10(_paramsList=paramsList,_GPU_ID=GPU_ID)",
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"language_info": {
"mimetype": "text/x-python",
"pygments_lexer": "ipython3",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"name": "python",
"version": "3.5.2",
"nbconvert_exporter": "python"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"gist": {
"id": "",
"data": {
"description": "mcdn/code/main_cifar10_config.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment