-
-
Save T-STAR-LTD/6332d8b846ae4810576f1dc8374c3321 to your computer and use it in GitHub Desktop.
Online Label Smoothing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"accelerator": "TPU", | |
"colab": { | |
"name": "Online Label Smoothing", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/T-STAR-LTD/6332d8b846ae4810576f1dc8374c3321/online-label-smoothing.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "HDAo0_qCnTAC", | |
"outputId": "878afa2a-0aed-435e-c810-370c0bac5eef" | |
}, | |
"source": [ | |
"import tensorflow as tf\n", | |
"import tensorflow.keras.backend as K\n", | |
"import numpy as np\n", | |
"import time\n", | |
"import datetime\n", | |
"from pytz import timezone\n", | |
"\n", | |
"tz = timezone('Asia/Tokyo')\n", | |
"start_datetime = datetime.datetime.now(tz=tz )\n", | |
"print(start_datetime.isoformat())\n", | |
"\n", | |
"train_data, validation_data = tf.keras.datasets.cifar10.load_data()\n", | |
"\n" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"2021-03-30T16:54:57.868533+09:00\n", | |
"Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\n", | |
"170500096/170498071 [==============================] - 2s 0us/step\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "OEXmnjqrnvTo" | |
}, | |
"source": [ | |
"def basic_augmentation(x):\n", | |
" x = tf.image.random_flip_left_right(x)\n", | |
" x = tf.pad(x, [ [4,4], [4,4], [0,0]], \"reflect\")\n", | |
" x = tf.image.random_crop(x, [32,32,3])\n", | |
" return x\n", | |
"\n", | |
"\n", | |
"def cutout_sub(images, probability=None, cval=0, csize = 0.25):\n", | |
" DIM = images.shape[0]\n", | |
" x = tf.cast( tf.random.uniform([],0,DIM),tf.int32)\n", | |
" y = tf.cast( tf.random.uniform([],0,DIM),tf.int32)\n", | |
"\n", | |
" WIDTH = tf.cast( csize*DIM,tf.int32)\n", | |
" ya = tf.math.maximum(0,y-WIDTH//2)\n", | |
" yb = tf.math.minimum(DIM,y+WIDTH//2)\n", | |
" xa = tf.math.maximum(0,x-WIDTH//2)\n", | |
" xb = tf.math.minimum(DIM,x+WIDTH//2)\n", | |
"\n", | |
" one = images[ya:yb,0:xa,:]\n", | |
" two = tf.fill([yb-ya,xb-xa,3], tf.cast(cval, images.dtype) ) \n", | |
" three = images[ya:yb,xb:DIM,:]\n", | |
" middle = tf.concat([one,two,three],axis=1)\n", | |
" images2 = tf.concat([images[0:ya,:,:],middle,images[yb:DIM,:,:]],axis=0)\n", | |
"\n", | |
" images2 = tf.reshape(images2,[DIM,DIM,3])\n", | |
"\n", | |
" if probability != None:\n", | |
" images2 = tf.where( tf.random.uniform([],0.0,1.0)<=probability, images2, images)\n", | |
"\n", | |
" return images2\n", | |
"\n", | |
"def mix_sub(images, labels,num_classes, probability, j, mix_type):\n", | |
" img = images[j]\n", | |
" label = labels[j]\n", | |
" \n", | |
" DIM = img.shape[0]\n", | |
"\n", | |
" k = tf.random.uniform([],0,num_classes,dtype=tf.int32)\n", | |
" is_same_class = tf.reduce_all(tf.math.equal(label, labels[k]))\n", | |
" k = tf.where(is_same_class, (k+1)%num_classes, k)\n", | |
" is_same_class = tf.reduce_all(tf.math.equal(label, labels[k]))\n", | |
" k = tf.where(is_same_class, (k+1)%num_classes, k)\n", | |
"\n", | |
" P = tf.cast( tf.random.uniform([],0.0,1.0)<=probability, tf.int32)\n", | |
"\n", | |
" if mix_type==0:\n", | |
" a = tf.random.uniform([],0,0.5) * tf.cast(P, tf.float32)\n", | |
" img1 = images[j,]\n", | |
" img2 = images[k,]\n", | |
" img = (1-a)*img1 + a*img2\n", | |
" img = tf.reshape(img,(DIM,DIM,3))\n", | |
" else: \n", | |
" b = tf.random.uniform([],0,1)\n", | |
" w = tf.cast( DIM * tf.math.sqrt(1-b),tf.int32) * P\n", | |
" x = tf.cast( tf.random.uniform([],0,DIM),tf.int32)\n", | |
" y = tf.cast( tf.random.uniform([],0,DIM),tf.int32)\n", | |
" ya = tf.math.maximum(0,y-w//2)\n", | |
" yb = tf.math.minimum(DIM,y+w//2)\n", | |
" xa = tf.math.maximum(0,x-w//2)\n", | |
" xb = tf.math.minimum(DIM,x+w//2)\n", | |
"\n", | |
" one = img[ya:yb,0:xa,:]\n", | |
" two = images[k,ya:yb,xa:xb,:]\n", | |
" three = img[ya:yb,xb:DIM,:]\n", | |
" middle = tf.concat([one,two,three],axis=1)\n", | |
" img2 = tf.concat([img[0:ya,:,:],middle,img[yb:DIM,:,:]],axis=0)\n", | |
" img = tf.reshape(img2,(DIM,DIM,3))\n", | |
" \n", | |
" a = tf.cast(w*w/DIM/DIM,tf.float32)\n", | |
"\n", | |
" lab1 = tf.one_hot(labels[j],num_classes)\n", | |
" lab2 = tf.one_hot(labels[k],num_classes)\n", | |
" label = (1-a)*lab1 + a*lab2\n", | |
" label = tf.reshape(label,(num_classes,))\n", | |
"\n", | |
" return img,label\n", | |
"\n", | |
"@tf.function\n", | |
"def cutout(images, labels, num_classes, probability = 1.0, cval=0.0, csize=0.5):\n", | |
" images2 = tf.map_fn(lambda image: \n", | |
" cutout_sub(image, probability=probability, cval=cval, csize=csize),\n", | |
" images)\n", | |
" labels2 = to_one_hot(labels,num_classes)\n", | |
" return images2,labels2\n", | |
"\n", | |
"@tf.function\n", | |
"def mixup(images, labels, num_classes, probability = 1.0):\n", | |
" images = tf.cast(images, tf.float32)\n", | |
" BATCH_SIZE = images.shape[0]\n", | |
" DIM = images.shape[1]\n", | |
"\n", | |
" elems = tf.range(0,BATCH_SIZE,dtype=tf.int32)\n", | |
" images2, labels2 = tf.map_fn(lambda x: mix_sub(images,labels, num_classes, probability, x, 0), elems, fn_output_signature=(tf.float32,tf.float32))\n", | |
" image2 = tf.reshape(images,(BATCH_SIZE,DIM,DIM,3))\n", | |
" labels2 = tf.reshape(labels2,(BATCH_SIZE,num_classes))\n", | |
" return images2,labels2\n", | |
"\n", | |
"\n", | |
"@tf.function\n", | |
"def cutmix(images, labels, num_classes, probability = 1.0):\n", | |
" images = tf.cast(images, tf.float32)\n", | |
" BATCH_SIZE = images.shape[0]\n", | |
" DIM = images.shape[1]\n", | |
"\n", | |
" elems = tf.range(0,BATCH_SIZE,dtype=tf.int32)\n", | |
"\n", | |
" images2, labels2 = tf.map_fn(lambda x: mix_sub(images,labels, num_classes, probability, x, 1), elems, fn_output_signature=(tf.float32,tf.float32))\n", | |
" image2 = tf.reshape(images,(BATCH_SIZE,DIM,DIM,3))\n", | |
" labels2 = tf.reshape(labels2,(BATCH_SIZE,num_classes))\n", | |
" return images2,labels2\n", | |
"\n", | |
"\n", | |
"def to_one_hot(labels, num_classes):\n", | |
" # num_batch = labels.shape[0]\n", | |
" labels2 = tf.one_hot(labels, num_classes)\n", | |
" labels2 = tf.reshape(labels2, [-1, num_classes])\n", | |
" return labels2\n", | |
"\n", | |
"def make_dataset( train_data, validation_data, batch_size, num_classes, aug_type=\"basic\", probability=1.0, cutout_size=0.25):\n", | |
" (x_train, label_train)= train_data\n", | |
" (x_test, label_test)=validation_data\n", | |
"\n", | |
" train_len = len(x_train)\n", | |
" test_len = len(x_test)\n", | |
"\n", | |
" B, W, H, C = x_train.shape\n", | |
" mean_val = np.mean( np.reshape(x_train, (B*W*H, C)), axis=0 ) / 255.0\n", | |
" std_val = np.std( np.reshape(x_train, (B*W*H, C)), axis=0 ) / 255.0\n", | |
"\n", | |
" ds_train = tf.data.Dataset.from_tensor_slices(train_data)\n", | |
" ds_train = ds_train.shuffle(train_len).batch(batch_size, drop_remainder=True)\n", | |
" ds_validation = tf.data.Dataset.from_tensor_slices(validation_data)\n", | |
" ds_validation = ds_validation.batch(batch_size)\n", | |
"\n", | |
" # @tf.function\n", | |
" def data_augmentation(images,labels,aug=True,cutout_size=0.25):\n", | |
"\n", | |
" images = tf.cast(images, tf.float32)/255.0\n", | |
" image_mean = tf.constant(mean_val, dtype=tf.float32)\n", | |
" image_std = tf.constant(std_val, dtype=tf.float32)\n", | |
" images = (images-image_mean)/image_std\n", | |
" if aug:\n", | |
" images = tf.map_fn(basic_augmentation, images)\n", | |
" if aug_type==\"cutout\":\n", | |
" images,labels = cutout(images, labels, num_classes,probability, cval=0.0,csize=cutout_size)\n", | |
" elif aug_type==\"cutmix\":\n", | |
" images,labels = cutmix(images, labels, num_classes,probability)\n", | |
" elif aug_type==\"mixup\":\n", | |
" images,labels = mixup(images, labels, num_classes,probability)\n", | |
" else:\n", | |
" print('basic augmentation')\n", | |
" labels = to_one_hot(labels,num_classes)\n", | |
"\n", | |
" else:\n", | |
" labels = to_one_hot(labels,num_classes)\n", | |
" return images, labels\n", | |
"\n", | |
" ds_train = ds_train.map(lambda image, label: data_augmentation(image, label, True, cutout_size ), num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", | |
" ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)\n", | |
"\n", | |
" ds_validation = ds_validation.map(lambda image, label: data_augmentation(image, label, False ), num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", | |
" ds_validation = ds_validation.prefetch(tf.data.experimental.AUTOTUNE)\n", | |
" return ds_train,ds_validation\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"def showHistory(history):\n", | |
"\n", | |
" # Setting Parameters\n", | |
" acc = history['acc']\n", | |
" val_acc = history['val_acc']\n", | |
" print( 'max_acc', max(val_acc) )\n", | |
"\n", | |
" loss = history['loss']\n", | |
" val_loss = history['val_loss']\n", | |
"\n", | |
" epochs = range(len(acc))\n", | |
"\n", | |
" plt.figure(figsize=(16,6))\n", | |
" \n", | |
" # Accracy\n", | |
" plt.subplot(1,2,1)\n", | |
" plt.plot(epochs, acc, 'r', label='Training')\n", | |
" plt.plot(epochs, val_acc, 'b', label='Validation')\n", | |
" plt.title('Accuracy')\n", | |
" plt.grid()\n", | |
" plt.legend()\n", | |
"\n", | |
" # Loss \n", | |
" plt.subplot(1,2,2)\n", | |
" plt.plot(epochs, loss, 'r', label='Training')\n", | |
" plt.plot(epochs, val_loss, 'b', label='Validation')\n", | |
" plt.title('Loss')\n", | |
" plt.grid()\n", | |
" plt.legend()\n", | |
" plt.show()\n", | |
"\n" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "E7coBk2LhuU3" | |
}, | |
"source": [ | |
"# https://arxiv.org/abs/2011.12562\n", | |
"# https://github.com/zhangchbin/OnlineLabelSmooth\n", | |
"class ols_categorical_crossentropy(tf.keras.losses.Loss):\n", | |
" def __init__(self, num_classes, steps_per_epoch, alpha=0.5, name=\"ols_categorical_crossentropy\"):\n", | |
" super().__init__(name=name)\n", | |
" self.num_classes = num_classes\n", | |
" self.steps_per_epoch = steps_per_epoch\n", | |
" self.alpha = alpha\n", | |
" self.steps = tf.Variable(0, dtype=tf.int32, name=\"stepcounter\")\n", | |
" self.training = tf.Variable(True, dtype=tf.bool, name=\"training\")\n", | |
" self.S = tf.Variable(\n", | |
" tf.zeros(shape=tf.TensorShape([num_classes,num_classes]), dtype=tf.float32), \n", | |
" shape=tf.TensorShape([num_classes,num_classes]), dtype=tf.float32, name=\"S\")\n", | |
"\n", | |
" self.hard_label = tf.eye(num_classes, dtype=tf.float32)\n", | |
"\n", | |
" self.soft_label = tf.Variable(\n", | |
" np.ones((num_classes, num_classes),dtype='float32')/num_classes,\n", | |
" shape=tf.TensorShape([num_classes,num_classes]), dtype=tf.float32, name=\"softlabel\")\n", | |
"\n", | |
" def set_training_phase(self, flag):\n", | |
" self.training.assign( flag )\n", | |
"\n", | |
"\n", | |
"\n", | |
" def call(self, y_true, y_pred, **kwargs):\n", | |
" \n", | |
" indices_pred = tf.math.argmax(y_pred,axis=1)\n", | |
" indices_true = tf.math.argmax(y_true,axis=1)\n", | |
"\n", | |
" softlabel = self.alpha*self.hard_label + (1.0-self.alpha)*self.soft_label\n", | |
" y_true_soft = tf.gather(softlabel, indices_true)\n", | |
" y_true_soft = tf.squeeze(y_true_soft)\n", | |
" loss = tf.keras.losses.categorical_crossentropy(y_true_soft, y_pred)\n", | |
" \n", | |
" def noop():\n", | |
" pass\n", | |
"\n", | |
" def online():\n", | |
" def update():\n", | |
" # update steps\n", | |
" self.steps.assign_add(1)\n", | |
"\n", | |
" # update S\n", | |
" correct_indices = tf.where( tf.math.equal(indices_pred, indices_true) )\n", | |
" correct_labels = tf.gather(indices_pred, correct_indices)\n", | |
" correct_p = tf.gather(y_pred, correct_indices)\n", | |
"\n", | |
" correct_p = tf.squeeze( correct_p, axis=1)\n", | |
" S = tf.tensor_scatter_nd_add( self.S, correct_labels, correct_p )\n", | |
" self.S.assign(S)\n", | |
" \n", | |
" def update_and_reset():\n", | |
" update()\n", | |
"\n", | |
" # update softlabel\n", | |
" S = tf.distribute.get_replica_context().all_reduce('sum', self.S) # for TPU\n", | |
" norm = tf.math.reduce_sum( S, axis=1)\n", | |
" soft_label = tf.where(norm>0, tf.transpose(S)/norm, tf.transpose(self.soft_label))\n", | |
" self.soft_label.assign(tf.transpose(soft_label))\n", | |
" # reset\n", | |
" self.S.assign_sub(self.S)\n", | |
" self.steps.assign(0)\n", | |
"\n", | |
" tf.cond(self.steps == self.steps_per_epoch, update_and_reset, update)\n", | |
"\n", | |
" tf.cond(self.training==True, online, noop)\n", | |
" return loss\n", | |
" \n", | |
"class OLSModel(tf.keras.models.Model):\n", | |
" def __init__(self, model, num_classes, steps_per_epoch, **kwargs):\n", | |
" self.lossobj = ols_categorical_crossentropy(num_classes, steps_per_epoch)\n", | |
" in_shape = model.input_shape[1:]\n", | |
" inputs = tf.keras.layers.Input(shape=in_shape)\n", | |
" super().__init__( inputs=inputs, outputs=model(inputs) , **kwargs)\n", | |
" def train_step(self, data):\n", | |
" self.lossobj.set_training_phase(True)\n", | |
" return super().train_step(data)\n", | |
" def test_step(self, data):\n", | |
" self.lossobj.set_training_phase(False)\n", | |
" return super().test_step(data)\n", | |
" def predict_step(self, data):\n", | |
" self.lossobj.set_training_phase(False)\n", | |
" return super().predict_step(data)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "yuk9ExNr0FQf" | |
}, | |
"source": [ | |
"import pickle\n", | |
"from tensorflow.keras import layers\n", | |
"\n", | |
"def wide_res_net( depth,width_factor, input_shape, num_outputs, top_dropout=0.0,l2_reg=0.0):\n", | |
" filters = [16, 1 * 16 * width_factor, 2 * 16 * width_factor, 4 * 16 * width_factor]\n", | |
" n = int((depth-4)/6)\n", | |
"\n", | |
" CONV_KWARGS = {\n", | |
" 'use_bias': False,\n", | |
" 'padding': 'same',\n", | |
" 'kernel_initializer': tf.keras.initializers.VarianceScaling(scale=2.0, mode='fan_out', distribution='normal'),\n", | |
" 'kernel_regularizer': tf.keras.regularizers.l2(l2_reg)\n", | |
" }\n", | |
" BN_KWARGS = {\n", | |
" 'momentum': 0.9,\n", | |
" 'epsilon' : 1e-5 \n", | |
" }\n", | |
" init_range = 1.0 / (num_outputs)**(0.5)\n", | |
" DENSE_KWARGS = {\n", | |
" 'kernel_initializer' : tf.keras.initializers.RandomUniform(minval=-init_range, maxval=init_range) , \n", | |
" 'kernel_regularizer': tf.keras.regularizers.l2(l2_reg)\n", | |
" }\n", | |
" def res_basic(x, filters, strides, pre_act=False):\n", | |
" shortcut = x\n", | |
" if pre_act:\n", | |
" x = layers.BatchNormalization( **BN_KWARGS )(x)\n", | |
" x = layers.ReLU()(x)\n", | |
" x = layers.Conv2D(filters, 3, strides,**CONV_KWARGS)(x)\n", | |
" x = layers.BatchNormalization( **BN_KWARGS )(x)\n", | |
" x = layers.ReLU()(x)\n", | |
" x = layers.Conv2D(filters, 3, 1,**CONV_KWARGS)(x)\n", | |
" x = layers.BatchNormalization( **BN_KWARGS )(x)\n", | |
" x = layers.ReLU()(x)\n", | |
" \n", | |
" shortcut_shape = K.int_shape(shortcut)\n", | |
" x_shape = K.int_shape(x)\n", | |
" downsample_strides = 2 if shortcut_shape[1] != x_shape[1] else 1\n", | |
" if downsample_strides != 1:\n", | |
" shortcut = layers.AveragePooling2D((downsample_strides,downsample_strides), downsample_strides)(shortcut)\n", | |
" if shortcut_shape[3] != x_shape[3]:\n", | |
" shortcut = tf.pad(shortcut, [ [0,0],[0,0],[0,0],[0,x_shape[3]-shortcut_shape[3]] ] )\n", | |
"\n", | |
" return shortcut+x\n", | |
"\n", | |
" x = inputs = layers.Input(shape=input_shape)\n", | |
" x = layers.Conv2D(filters[0], 3, 1, **CONV_KWARGS)(x)\n", | |
" for i in range(n):\n", | |
" x = res_basic(x, filters[1], strides=1, pre_act = (i == 0) )\n", | |
" for i in range(n):\n", | |
" x = res_basic(x, filters[2], strides=2 if i == 0 else 1 )\n", | |
" for i in range(n):\n", | |
" x = res_basic(x, filters[3], strides=2 if i == 0 else 1 )\n", | |
" x = layers.BatchNormalization( **BN_KWARGS )(x)\n", | |
" x = layers.ReLU()(x)\n", | |
" x = layers.GlobalAveragePooling2D()(x)\n", | |
" x = layers.Dropout(top_dropout)(x)\n", | |
" x = layers.Dense(units=num_outputs, **DENSE_KWARGS)(x)\n", | |
" x = layers.Activation('softmax')(x)\n", | |
" model = tf.keras.models.Model(inputs,x) \n", | |
" model.build(input_shape)\n", | |
" return model\n", | |
"\n", | |
"def train(model, optimizer, callbacks, epochs, ds_train, ds_validation, verbose=1, label_smoothing=0.0, loss=None):\n", | |
" \n", | |
" if loss == None:\n", | |
" loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False, label_smoothing=label_smoothing)\n", | |
"\n", | |
" acc = tf.keras.metrics.CategoricalAccuracy(name='acc')\n", | |
" model.compile(optimizer=optimizer,\n", | |
" loss=loss,\n", | |
" metrics=[acc], steps_per_execution=10)\n", | |
" history = model.fit(ds_train,epochs=epochs,\n", | |
" validation_data=ds_validation,verbose=verbose, callbacks=callbacks)\n", | |
" \n", | |
" return history.history\n", | |
"\n", | |
"def train_main():\n", | |
" wrn_depth = 22\n", | |
" wrn_width_factor = 8\n", | |
"\n", | |
" num_classes = 10\n", | |
"\n", | |
" warmup_epochs = 10\n", | |
" cooldown_epochs = 190\n", | |
" EPOCHS = warmup_epochs + cooldown_epochs\n", | |
"\n", | |
" BATCH_SIZE = 512\n", | |
"\n", | |
" cutout_size = 0.5\n", | |
"\n", | |
" def scheduler(epoch, lr):\n", | |
" flat_epochs = EPOCHS - warmup_epochs - cooldown_epochs\n", | |
" if epoch < warmup_epochs:\n", | |
" return min_lr + (max_lr-min_lr)*epoch/warmup_epochs\n", | |
" elif epoch < warmup_epochs+flat_epochs:\n", | |
" return max_lr\n", | |
" else:\n", | |
" epoch = epoch - (warmup_epochs+flat_epochs) + 1\n", | |
" return min_lr + 0.5*(max_lr-min_lr)*(1.0+np.cos(epoch/cooldown_epochs*np.pi))\n", | |
" class CustomLogger(tf.keras.callbacks.Callback):\n", | |
" def __init__(self, lossobj=None):\n", | |
" self.lossobj = lossobj\n", | |
"\n", | |
" def on_epoch_end(self, epoch, logs):\n", | |
" lr = K.get_value(self.model.optimizer.lr)\n", | |
" logs['lr'] = lr\n", | |
" if self.lossobj!=None:\n", | |
" matrix = self.lossobj.matrix\n", | |
" print(matrix)\n", | |
" \n", | |
" test_list = [\n", | |
" # {'optim':'sgd', 'max_lr':1e-1, 'min_lr':1e-3, 'wd':5e-4, 'do':0.0, 'aug_type':'cutout', 'probability':1.0, 'ls':0.0},\n", | |
" # {'optim':'sgd', 'max_lr':1e-1, 'min_lr':1e-3, 'wd':5e-4, 'do':0.0, 'aug_type':'cutout', 'probability':1.0, 'ls':0.1},\n", | |
" {'optim':'sgd', 'max_lr':1e-1, 'min_lr':1e-3, 'wd':5e-4, 'do':0.0, 'aug_type':'cutout', 'probability':1.0, 'ls':-1.0},\n", | |
" ]\n", | |
"\n", | |
" for test in test_list:\n", | |
" optim = test['optim']\n", | |
" max_lr = test['max_lr']\n", | |
" min_lr = test['min_lr']\n", | |
" aug_type = test['aug_type']\n", | |
" probability = test['probability']\n", | |
" wd = test['wd']\n", | |
" dropout = test['do']\n", | |
" label_smoothing = test['ls']\n", | |
" if optim=='sgd':\n", | |
" optimizer = tf.optimizers.SGD(momentum=0.9, global_clipnorm=5.0)\n", | |
" elif optim=='adam':\n", | |
" optimizer = tf.optimizers.Adam(epsilon=1e-4, global_clipnorm=5.0)\n", | |
"\n", | |
" lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler)\n", | |
"\n", | |
" ds_train, ds_validation = make_dataset( train_data, validation_data, BATCH_SIZE, num_classes, aug_type=aug_type, probability=probability, cutout_size=cutout_size)\n", | |
"\n", | |
" train_model = wide_res_net( wrn_depth, wrn_width_factor, (32,32,3), num_classes, dropout, l2_reg=wd)\n", | |
" if label_smoothing<0:\n", | |
" train_model = OLSModel( train_model, num_classes, int(50000/BATCH_SIZE) )\n", | |
" lossobj = train_model.lossobj\n", | |
" else:\n", | |
" lossobj = None\n", | |
"\n", | |
" logger = CustomLogger()\n", | |
" start_time = time.time()\n", | |
" start_datetime = datetime.datetime.now(tz=tz)\n", | |
" print('Start:',start_datetime.isoformat())\n", | |
" print( optim, 'LR:',max_lr, 'WD:', wd,'AUG:', aug_type, 'LS:', label_smoothing )\n", | |
" history = train( train_model, optimizer, [lr_scheduler,logger], EPOCHS,\n", | |
" ds_train, ds_validation, \n", | |
" verbose=1, label_smoothing=label_smoothing, loss=lossobj)\n", | |
" end_datetime = datetime.datetime.now(tz=tz)\n", | |
" print('End:',end_datetime.strftime(\"%Y/%m/%d %H:%M:%S\"))\n", | |
" print('Time:', time.time()-start_time )\n", | |
"\n", | |
" if lossobj!=None:\n", | |
" print(lossobj.soft_label)\n", | |
"\n", | |
" showHistory(history)\n", | |
" f = open( f'{end_datetime.strftime(\"%Y%m%d%H%M%S\")}.hist', 'wb')\n", | |
" pickle.dump(history, f)\n", | |
" train_model.save_weights(\"weights.h5\")\n", | |
"\n", | |
"try:\n", | |
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection\n", | |
" tf.config.experimental_connect_to_cluster(tpu)\n", | |
" tf.tpu.experimental.initialize_tpu_system(tpu)\n", | |
" strategy = tf.distribute .TPUStrategy(tpu)\n", | |
"except ValueError:\n", | |
" strategy = tf.distribute.get_strategy()\n", | |
"\n", | |
"with strategy.scope():\n", | |
" train_main()\n" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "7JRS4RuVXs2N" | |
}, | |
"source": [ | |
"" | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment