Created
October 2, 2020 03:14
-
-
Save sjchoi86/70a13df64fd9ba452ad94522b90418c0 to your computer and use it in GitHub Desktop.
Contrastive loss in SimCLR
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Contrastive Loss for Self-supervised Learning" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "import numpy as np\nimport tensorflow as tf\nnp.set_printoptions(precision=3)\nprint (\"Done.\")", | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "Done.\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Helper functions" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def gpu_sess(): \n config = tf.ConfigProto(); \n config.gpu_options.allow_growth=True\n sess = tf.Session(config=config)\n return sess \ndef print_tf_tensor(sess,tf_tensor):\n tf_tensor_val = sess.run(tf_tensor)\n print (\"[%s] shape:%s\"%(tf_tensor.name,tf_tensor_val.shape))\n print (tf_tensor_val)", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Simple illustration of the original NCE loss\nhttps://github.com/google-research/simclr/blob/master/objective.py" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Start from an empty default graph and a new session\ntf.reset_default_graph()\nsess = gpu_sess()\nn_batch,dim = 5,2\n# Random (2*n_batch, dim) matrix cast to float32; the following cells split it row-wise into two halves\nhidden_concat = tf.cast(tf.Variable(np.random.rand(n_batch*2,dim)),tf.float32,name='hidden_concat')\n# Variables must be initialized before the tensor can be evaluated\nsess.run(tf.global_variables_initializer())\nprint_tf_tensor(sess,hidden_concat)", | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[hidden_concat:0] shape:(10, 2)\n[[0.648 0.706]\n [0.509 0.362]\n [0.462 0.338]\n [0.925 0.73 ]\n [0.342 0.013]\n [0.308 0.826]\n [0.462 0.726]\n [0.47 0.977]\n [0.708 0.105]\n [0.326 0.858]]\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true, | |
"scrolled": true | |
}, | |
"cell_type": "code", | |
"source": "# L2 normalize along the last axis (each row becomes a unit vector)\nhidden_concat_nzd = tf.math.l2_normalize(hidden_concat,-1)\n# Split into half along axis 0 (the halves originally came from two separate feature maps)\nhidden1,hidden2 = tf.split(hidden_concat_nzd,num_or_size_splits=2,axis=0)\n# Show the normalized matrix followed by its two halves\nprint_tf_tensor(sess,hidden_concat_nzd)\nprint_tf_tensor(sess,hidden1)\nprint_tf_tensor(sess,hidden2)", | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[l2_normalize:0] shape:(10, 2)\n[[0.676 0.737]\n [0.815 0.579]\n [0.807 0.59 ]\n [0.785 0.62 ]\n [0.999 0.037]\n [0.349 0.937]\n [0.537 0.844]\n [0.434 0.901]\n [0.989 0.147]\n [0.355 0.935]]\n[split:0] shape:(5, 2)\n[[0.676 0.737]\n [0.815 0.579]\n [0.807 0.59 ]\n [0.785 0.62 ]\n [0.999 0.037]]\n[split:1] shape:(5, 2)\n[[0.349 0.937]\n [0.537 0.844]\n [0.434 0.901]\n [0.989 0.147]\n [0.355 0.935]]\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Label and mask\n# NOTE(review): the *_large tensors are plain aliases of hidden1/hidden2 here; presumably they\n# stand in for the cross-replica gathered features of the linked SimCLR objective.py -- confirm\nhidden1_large = hidden1 # alias of hidden1\nhidden2_large = hidden2 # alias of hidden2\n# labels: row i is one-hot at column i of a 2*n_batch-wide row\nlabels = tf.one_hot(tf.range(n_batch),n_batch*2,name='labels')\n# masks: n_batch x n_batch identity matrix, used later to knock out the diagonal (self-similarity)\nmasks = tf.one_hot(tf.range(n_batch),n_batch,name='masks')\nprint_tf_tensor(sess,labels)\nprint_tf_tensor(sess,masks)", | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[labels:0] shape:(5, 10)\n[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]\n[masks:0] shape:(5, 5)\n[[1. 0. 0. 0. 0.]\n [0. 1. 0. 0. 0.]\n [0. 0. 1. 0. 0.]\n [0. 0. 0. 1. 0.]\n [0. 0. 0. 0. 1.]]\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# [G_aa]_{i,j} = cossim(xa_i,xa_j)/temperature - LARGE_NUM*delta_{ij}\n# (the rows are L2-normalized, so the matmul yields cosine similarities, not distances)\nLARGE_NUM = 1e9\ntemperature = 1.0\nlogits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature\n# Subtracting LARGE_NUM on the diagonal makes each sample's self-similarity vanish under softmax\nlogits_aa = tf.subtract(logits_aa,LARGE_NUM*masks,name='logits_aa')\nprint_tf_tensor(sess,logits_aa)", | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[logits_aa:0] shape:(5, 5)\n[[-1.000e+09 9.779e-01 9.807e-01 9.873e-01 7.034e-01]\n [ 9.779e-01 -1.000e+09 9.999e-01 9.987e-01 8.364e-01]\n [ 9.807e-01 9.999e-01 -1.000e+09 9.993e-01 8.288e-01]\n [ 9.873e-01 9.987e-01 9.993e-01 -1.000e+09 8.075e-01]\n [ 7.034e-01 8.364e-01 8.288e-01 8.075e-01 -1.000e+09]]\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# [G_bb]_{i,j} = cossim(xb_i,xb_j)/temperature - LARGE_NUM*delta_{ij}\n# (cosine similarity within the second half, with the diagonal masked out)\nlogits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature\nlogits_bb = tf.subtract(logits_bb,LARGE_NUM*masks,name='logits_bb')\nprint_tf_tensor(sess,logits_bb)", | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[logits_bb:0] shape:(5, 5)\n[[-1.000e+09 9.780e-01 9.958e-01 4.834e-01 1.000e+00]\n [ 9.780e-01 -1.000e+09 9.930e-01 6.552e-01 9.794e-01]\n [ 9.958e-01 9.930e-01 -1.000e+09 5.616e-01 9.964e-01]\n [ 4.834e-01 6.552e-01 5.616e-01 -1.000e+09 4.891e-01]\n [ 1.000e+00 9.794e-01 9.964e-01 4.891e-01 -1.000e+09]]\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# [G_ab]_{i,j} = cossim(xa_i,xb_j)/temperature; the diagonal holds the corresponding pairs (i,i)\nlogits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature\n# [G_ba] is the transpose of [G_ab] here, because the *_large tensors are aliases of hidden1/hidden2\nlogits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature\nprint_tf_tensor(sess,logits_ab)\nprint_tf_tensor(sess,logits_ba)", | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[truediv_2:0] shape:(5, 5)\n[[0.926 0.985 0.957 0.777 0.929]\n [0.827 0.926 0.875 0.892 0.831]\n [0.835 0.931 0.882 0.885 0.838]\n [0.855 0.944 0.899 0.868 0.858]\n [0.384 0.568 0.467 0.994 0.39 ]]\n[truediv_3:0] shape:(5, 5)\n[[0.926 0.827 0.835 0.855 0.384]\n [0.985 0.926 0.931 0.944 0.568]\n [0.957 0.875 0.882 0.899 0.467]\n [0.777 0.892 0.885 0.868 0.994]\n [0.929 0.831 0.838 0.858 0.39 ]]\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### $\\text{Given } \\{ x^a_i, x^b_i \\}_{i=1}^N \\text{ where }x^a_i \\text{ and } x^b_i \\text{ form a corresponding pair.}$" | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### $\\quad \\text{loss_a} = -\\frac{1}{N} \\sum_{i=1}^N \n\\left(\n\\log \\frac{\\exp( sim(x^a_i,~x^b_i) )}\n{\\sum_{k=1}^N \\exp( sim(x^a_i,~x^b_k)) ~+~ \\sum_{k=1, k \\neq i}^N \\exp( sim(x^a_i,~x^a_k)) } \n\\right) $" | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### $\\quad \\text{loss_b} = -\\frac{1}{N} \\sum_{i=1}^N \n\\left(\n\\log \\frac{\\exp( sim(x^b_i,~x^a_i) )}\n{\\sum_{k=1}^N \\exp( sim(x^b_i,~x^a_k)) ~+~ \\sum_{k=1, k \\neq i}^N \\exp( sim(x^b_i,~x^b_k)) } \n\\right) $" | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### $\\quad \\text{loss} = \\text{loss_a} + \\text{loss_b}$" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Show the labels again: row i marks column i, which falls in the first n_batch columns of the\n# concatenated logits built in the next cell\nprint_tf_tensor(sess,labels)", | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[labels:0] shape:(5, 10)\n[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Define the nce loss function\n# The cross-view logits come first in each concat, so label column i selects the pair (i,i)\nweights = 1.0\nloss_a = tf.losses.softmax_cross_entropy(\n labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)\nloss_b = tf.losses.softmax_cross_entropy(\n labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)\nloss = loss_a + loss_b\n# This prints the symbolic tensor, not its value; evaluate it with sess.run(loss) to get a number\nprint (loss)", | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "Tensor(\"add:0\", shape=(), dtype=float32)\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Wrap it up with a function" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def get_nce_loss(hidden_concat):\n # L2 normalize\n hidden_concat_nzd = tf.math.l2_normalize(hidden_concat,-1)\n # Split into half (which originally came from two separate feature maps)\n hidden1,hidden2 = tf.split(hidden_concat_nzd,num_or_size_splits=2,axis=0)\n # Label and mask\n hidden1_large = hidden1 \n hidden2_large = hidden2 \n # [G_aa]_{i,j} = cosdist(xa_i,xa_j) - LARGE_NUM*delta(xa_i,xa_j)\n LARGE_NUM = 1e9\n temperature = 1.0\n masks = tf.one_hot(tf.range(n_batch),n_batch,name='masks')\n logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature\n logits_aa = tf.subtract(logits_aa,LARGE_NUM*masks,name='logits_aa')\n # [G_bb]_{i,j} = cosdist(xb_i,xb_j) - LARGE_NUM*delta(xb_i,xb_j)\n logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature\n logits_bb = tf.subtract(logits_bb,LARGE_NUM*masks,name='logits_bb')\n # [G_ab]_{i,j} = cosdist(xa_i,xb_j)\n logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature\n logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature\n # Define the nce loss function\n labels = tf.one_hot(tf.range(n_batch),n_batch*2,name='labels')\n weights = 1.0\n loss_a = tf.losses.softmax_cross_entropy(\n labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)\n loss_b = tf.losses.softmax_cross_entropy(\n labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)\n nce_loss = loss_a + loss_b\n return nce_loss\nprint (\"Done.\")", | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "Done.\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Usage 1 (random feature maps)" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Start from an empty default graph and a new session\ntf.reset_default_graph()\nsess = gpu_sess()\nn_batch,dim = 5,2\n# Two independently drawn random feature maps: rows of h1 and h2 are unrelated\nh1 = tf.cast(tf.Variable(np.random.rand(n_batch,dim)),tf.float32,name='h1')\nh2 = tf.cast(tf.Variable(np.random.rand(n_batch,dim)),tf.float32,name='h2')\n# Stack along axis 0 so that get_nce_loss can split them back into two halves\nhidden_concat = tf.concat([h1,h2],axis=0)\nnce_loss = get_nce_loss(hidden_concat)\nsess.run(tf.global_variables_initializer())\nprint_tf_tensor(sess,h1)\nprint_tf_tensor(sess,h2)\nprint (\"Loss is [%.4f].\"%(sess.run(nce_loss)))", | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[h1:0] shape:(5, 2)\n[[0.283 0.299]\n [0.783 0.863]\n [0.334 0.133]\n [0.878 0.516]\n [0.95 0.637]]\n[h2:0] shape:(5, 2)\n[[0.858 0.817]\n [0.811 0.107]\n [0.765 0.798]\n [0.764 0.56 ]\n [0.515 0.597]]\nLoss is [4.4372].\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Usage 2 (similar feature maps)" | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Start from an empty default graph and a new session\ntf.reset_default_graph()\nsess = gpu_sess()\nn_batch,dim = 5,2\n# Both feature maps share the same random bias and differ only by noise scaled by eps,\n# so row i of h1 is close to row i of h2\nbias = np.random.randn(n_batch,dim)\neps = 0.1\nh1 = tf.cast(tf.Variable(bias+eps*np.random.randn(n_batch,dim)),tf.float32,name='h1')\nh2 = tf.cast(tf.Variable(bias+eps*np.random.randn(n_batch,dim)),tf.float32,name='h2')\n# Stack along axis 0 so that get_nce_loss can split them back into two halves\nhidden_concat = tf.concat([h1,h2],axis=0)\nnce_loss = get_nce_loss(hidden_concat)\nsess.run(tf.global_variables_initializer())\nprint_tf_tensor(sess,h1)\nprint_tf_tensor(sess,h2)\nprint (\"Loss is [%.4f].\"%(sess.run(nce_loss)))", | |
"execution_count": 41, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[h1:0] shape:(5, 2)\n[[-0.678 -0.31 ]\n [-0.706 1.504]\n [ 0.645 0.913]\n [-1.351 -1.435]\n [ 0.041 -0.077]]\n[h2:0] shape:(5, 2)\n[[-0.816 -0.507]\n [-0.865 1.495]\n [ 0.625 0.879]\n [-1.352 -1.448]\n [ 0.281 0.041]]\nLoss is [2.9659].\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.6.7", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"gist": { | |
"id": "a8b7a656ea4dbd4bb94bec8669280ac5", | |
"data": { | |
"description": "Contrastive loss in SimCLR ", | |
"public": true | |
} | |
}, | |
"_draft": { | |
"nbviewer_url": "https://gist.github.com/a8b7a656ea4dbd4bb94bec8669280ac5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment