@sjchoi86
Created October 2, 2020 03:14
Contrastive loss in SimCLR
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "### Contrastive Loss for Self-supervised Learning"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import numpy as np\nimport tensorflow as tf\nnp.set_printoptions(precision=3)\nprint (\"Done.\")",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": "Done.\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Helper functions"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def gpu_sess(): \n config = tf.ConfigProto(); \n config.gpu_options.allow_growth=True\n sess = tf.Session(config=config)\n return sess \ndef print_tf_tensor(sess,tf_tensor):\n tf_tensor_val = sess.run(tf_tensor)\n print (\"[%s] shape:%s\"%(tf_tensor.name,tf_tensor_val.shape))\n print (tf_tensor_val)",
"execution_count": 2,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Simple illustration of the original NCE loss\nhttps://github.com/google-research/simclr/blob/master/objective.py"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "tf.reset_default_graph()\nsess = gpu_sess()\nn_batch,dim = 5,2\nhidden_concat = tf.cast(tf.Variable(np.random.rand(n_batch*2,dim)),tf.float32,name='hidden_concat')\nsess.run(tf.global_variables_initializer())\nprint_tf_tensor(sess,hidden_concat)",
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": "[hidden_concat:0] shape:(10, 2)\n[[0.648 0.706]\n [0.509 0.362]\n [0.462 0.338]\n [0.925 0.73 ]\n [0.342 0.013]\n [0.308 0.826]\n [0.462 0.726]\n [0.47 0.977]\n [0.708 0.105]\n [0.326 0.858]]\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": "# L2 normalize\nhidden_concat_nzd = tf.math.l2_normalize(hidden_concat,-1)\n# Split into half (which originally came from two separate feature maps)\nhidden1,hidden2 = tf.split(hidden_concat_nzd,num_or_size_splits=2,axis=0)\nprint_tf_tensor(sess,hidden_concat_nzd)\nprint_tf_tensor(sess,hidden1)\nprint_tf_tensor(sess,hidden2)",
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": "[l2_normalize:0] shape:(10, 2)\n[[0.676 0.737]\n [0.815 0.579]\n [0.807 0.59 ]\n [0.785 0.62 ]\n [0.999 0.037]\n [0.349 0.937]\n [0.537 0.844]\n [0.434 0.901]\n [0.989 0.147]\n [0.355 0.935]]\n[split:0] shape:(5, 2)\n[[0.676 0.737]\n [0.815 0.579]\n [0.807 0.59 ]\n [0.785 0.62 ]\n [0.999 0.037]]\n[split:1] shape:(5, 2)\n[[0.349 0.937]\n [0.537 0.844]\n [0.434 0.901]\n [0.989 0.147]\n [0.355 0.935]]\n",
"name": "stdout"
}
]
},
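{
"metadata": {},
"cell_type": "markdown",
"source": "Note: every row is now unit-length, so the dot product of two rows equals their cosine similarity. The `tf.matmul(...,transpose_b=True)` calls below therefore compute pairwise cosine-similarity matrices."
},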
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Label and mask\nhidden1_large = hidden1 # ?\nhidden2_large = hidden2 # ?\nlabels = tf.one_hot(tf.range(n_batch),n_batch*2,name='labels')\nmasks = tf.one_hot(tf.range(n_batch),n_batch,name='masks')\nprint_tf_tensor(sess,labels)\nprint_tf_tensor(sess,masks)",
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": "[labels:0] shape:(5, 10)\n[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]\n[masks:0] shape:(5, 5)\n[[1. 0. 0. 0. 0.]\n [0. 1. 0. 0. 0.]\n [0. 0. 1. 0. 0.]\n [0. 0. 0. 1. 0.]\n [0. 0. 0. 0. 1.]]\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# [G_aa]_{i,j} = cosdist(xa_i,xa_j) - LARGE_NUM*delta(xa_i,xa_j)\nLARGE_NUM = 1e9\ntemperature = 1.0\nlogits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature\nlogits_aa = tf.subtract(logits_aa,LARGE_NUM*masks,name='logits_aa')\nprint_tf_tensor(sess,logits_aa)",
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": "[logits_aa:0] shape:(5, 5)\n[[-1.000e+09 9.779e-01 9.807e-01 9.873e-01 7.034e-01]\n [ 9.779e-01 -1.000e+09 9.999e-01 9.987e-01 8.364e-01]\n [ 9.807e-01 9.999e-01 -1.000e+09 9.993e-01 8.288e-01]\n [ 9.873e-01 9.987e-01 9.993e-01 -1.000e+09 8.075e-01]\n [ 7.034e-01 8.364e-01 8.288e-01 8.075e-01 -1.000e+09]]\n",
"name": "stdout"
}
]
},
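{
"metadata": {},
"cell_type": "markdown",
"source": "Why the mask works: subtracting `LARGE_NUM` from the diagonal drives $\\exp(\\text{logit})$ to essentially zero after the softmax, so each example's similarity to itself never acts as a negative. A quick sanity check (a minimal sketch on the tensors defined above):"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Sanity check: after masking, softmax assigns (numerically) zero probability to self-pairs\nprobs_aa = sess.run(tf.nn.softmax(logits_aa))\nprint (np.diag(probs_aa)) # diagonal entries should all be ~0",
"execution_count": null,
"outputs": []
},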
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# [G_bb]_{i,j} = cosdist(xb_i,xb_j) - LARGE_NUM*delta(xb_i,xb_j)\nlogits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature\nlogits_bb = tf.subtract(logits_bb,LARGE_NUM*masks,name='logits_bb')\nprint_tf_tensor(sess,logits_bb)",
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": "[logits_bb:0] shape:(5, 5)\n[[-1.000e+09 9.780e-01 9.958e-01 4.834e-01 1.000e+00]\n [ 9.780e-01 -1.000e+09 9.930e-01 6.552e-01 9.794e-01]\n [ 9.958e-01 9.930e-01 -1.000e+09 5.616e-01 9.964e-01]\n [ 4.834e-01 6.552e-01 5.616e-01 -1.000e+09 4.891e-01]\n [ 1.000e+00 9.794e-01 9.964e-01 4.891e-01 -1.000e+09]]\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# [G_ab]_{i,j} = cosdist(xa_i,xb_j)\nlogits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature\nlogits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature\nprint_tf_tensor(sess,logits_ab)\nprint_tf_tensor(sess,logits_ba)",
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": "[truediv_2:0] shape:(5, 5)\n[[0.926 0.985 0.957 0.777 0.929]\n [0.827 0.926 0.875 0.892 0.831]\n [0.835 0.931 0.882 0.885 0.838]\n [0.855 0.944 0.899 0.868 0.858]\n [0.384 0.568 0.467 0.994 0.39 ]]\n[truediv_3:0] shape:(5, 5)\n[[0.926 0.827 0.835 0.855 0.384]\n [0.985 0.926 0.931 0.944 0.568]\n [0.957 0.875 0.882 0.899 0.467]\n [0.777 0.892 0.885 0.868 0.994]\n [0.929 0.831 0.838 0.858 0.39 ]]\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### $\\text{Given } \\{ x^a_i, x^b_i \\}_{i=1}^N \\text{ where }x^a_i \\text{ and } x^b_i \\text{ has a correspondence.}$"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### $\\quad \\text{loss_a} = \\sum_{i=1}^N \n\\left(\n\\log \\frac{\\exp( sim(x^a_i,~x^b_i) )}\n{\\sum_{k=1}^K \\exp( sim(x^a_i,~x^b_k)) ~+~ \\sum_{k=1, k \\neq i}^K \\exp( sim(x^a_i,~x^a_k)) } \n\\right) $"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### $\\quad \\text{loss_b} = \\sum_{i=1}^N \n\\left(\n\\log \\frac{\\exp( sim(x^b_i,~x^a_i) )}\n{\\sum_{k=1}^K \\exp( sim(x^b_i,~x^a_k)) ~+~ \\sum_{k=1, k \\neq i}^K \\exp( sim(x^b_i,~x^b_k)) } \n\\right) $"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### $\\quad \\text{loss} = \\text{loss_a} + \\text{loss_b}$"
},
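{
"metadata": {},
"cell_type": "markdown",
"source": "This is the NT-Xent (normalized temperature-scaled cross entropy) loss from the SimCLR paper (Chen et al., 2020); the code here uses temperature $\\tau=1$, so the scaling is a no-op."
},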
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "print_tf_tensor(sess,labels)",
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": "[labels:0] shape:(5, 10)\n[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Define the nce loss function\nweights = 1.0\nloss_a = tf.losses.softmax_cross_entropy(\n labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)\nloss_b = tf.losses.softmax_cross_entropy(\n labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)\nloss = loss_a + loss_b\nprint (loss)",
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": "Tensor(\"add:0\", shape=(), dtype=float32)\n",
"name": "stdout"
}
]
},
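{
"metadata": {},
"cell_type": "markdown",
"source": "As a cross-check, `loss_a` can be recomputed by hand in NumPy (a minimal sketch; it reuses the session and tensors from the cells above and relies on `tf.losses.softmax_cross_entropy` averaging over the batch):"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Recompute loss_a with an explicit log-softmax in NumPy\nlogits_a_np,labels_np = sess.run([tf.concat([logits_ab,logits_aa],1),labels])\nlog_probs = logits_a_np - np.log(np.sum(np.exp(logits_a_np),axis=1,keepdims=True))\nloss_a_np = -np.mean(np.sum(labels_np*log_probs,axis=1))\nprint (\"loss_a (NumPy):[%.4f] loss_a (TF):[%.4f]\"%(loss_a_np,sess.run(loss_a)))",
"execution_count": null,
"outputs": []
},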
{
"metadata": {},
"cell_type": "markdown",
"source": "### Wrap it up with a function"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def get_nce_loss(hidden_concat):\n # L2 normalize\n hidden_concat_nzd = tf.math.l2_normalize(hidden_concat,-1)\n # Split into half (which originally came from two separate feature maps)\n hidden1,hidden2 = tf.split(hidden_concat_nzd,num_or_size_splits=2,axis=0)\n # Label and mask\n hidden1_large = hidden1 \n hidden2_large = hidden2 \n # [G_aa]_{i,j} = cosdist(xa_i,xa_j) - LARGE_NUM*delta(xa_i,xa_j)\n LARGE_NUM = 1e9\n temperature = 1.0\n masks = tf.one_hot(tf.range(n_batch),n_batch,name='masks')\n logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature\n logits_aa = tf.subtract(logits_aa,LARGE_NUM*masks,name='logits_aa')\n # [G_bb]_{i,j} = cosdist(xb_i,xb_j) - LARGE_NUM*delta(xb_i,xb_j)\n logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature\n logits_bb = tf.subtract(logits_bb,LARGE_NUM*masks,name='logits_bb')\n # [G_ab]_{i,j} = cosdist(xa_i,xb_j)\n logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature\n logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature\n # Define the nce loss function\n labels = tf.one_hot(tf.range(n_batch),n_batch*2,name='labels')\n weights = 1.0\n loss_a = tf.losses.softmax_cross_entropy(\n labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)\n loss_b = tf.losses.softmax_cross_entropy(\n labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)\n nce_loss = loss_a + loss_b\n return nce_loss\nprint (\"Done.\")",
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": "Done.\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Usage 1 (random feature maps)"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "tf.reset_default_graph()\nsess = gpu_sess()\nn_batch,dim = 5,2\nh1 = tf.cast(tf.Variable(np.random.rand(n_batch,dim)),tf.float32,name='h1')\nh2 = tf.cast(tf.Variable(np.random.rand(n_batch,dim)),tf.float32,name='h2')\nhidden_concat = tf.concat([h1,h2],axis=0)\nnce_loss = get_nce_loss(hidden_concat)\nsess.run(tf.global_variables_initializer())\nprint_tf_tensor(sess,h1)\nprint_tf_tensor(sess,h2)\nprint (\"Loss is [%.4f].\"%(sess.run(nce_loss)))",
"execution_count": 32,
"outputs": [
{
"output_type": "stream",
"text": "[h1:0] shape:(5, 2)\n[[0.283 0.299]\n [0.783 0.863]\n [0.334 0.133]\n [0.878 0.516]\n [0.95 0.637]]\n[h2:0] shape:(5, 2)\n[[0.858 0.817]\n [0.811 0.107]\n [0.765 0.798]\n [0.764 0.56 ]\n [0.515 0.597]]\nLoss is [4.4372].\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Usage 2 (similar feature maps)"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "tf.reset_default_graph()\nsess = gpu_sess()\nn_batch,dim = 5,2\nbias = np.random.randn(n_batch,dim)\neps = 0.1\nh1 = tf.cast(tf.Variable(bias+eps*np.random.randn(n_batch,dim)),tf.float32,name='h1')\nh2 = tf.cast(tf.Variable(bias+eps*np.random.randn(n_batch,dim)),tf.float32,name='h2')\nhidden_concat = tf.concat([h1,h2],axis=0)\nnce_loss = get_nce_loss(hidden_concat)\nsess.run(tf.global_variables_initializer())\nprint_tf_tensor(sess,h1)\nprint_tf_tensor(sess,h2)\nprint (\"Loss is [%.4f].\"%(sess.run(nce_loss)))",
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"text": "[h1:0] shape:(5, 2)\n[[-0.678 -0.31 ]\n [-0.706 1.504]\n [ 0.645 0.913]\n [-1.351 -1.435]\n [ 0.041 -0.077]]\n[h2:0] shape:(5, 2)\n[[-0.816 -0.507]\n [-0.865 1.495]\n [ 0.625 0.879]\n [-1.352 -1.448]\n [ 0.281 0.041]]\nLoss is [2.9659].\n",
"name": "stdout"
}
]
},
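{
"metadata": {},
"cell_type": "markdown",
"source": "Finally, a quick check that the loss is trainable: minimizing `nce_loss` by gradient descent should pull each pair $(h^1_i, h^2_i)$ together and push the other pairs apart. (A minimal sketch; the Adam optimizer, learning rate, and step count are illustrative choices, not part of the original gist.)"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Minimize the NCE loss over the raw features from the previous cell\noptm = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(nce_loss)\nsess.run(tf.global_variables_initializer()) # also initializes the Adam slot variables\nfor step in range(101):\n    _,loss_val = sess.run([optm,nce_loss])\n    if (step%25) == 0:\n        print (\"step:[%d] loss:[%.4f]\"%(step,loss_val))",
"execution_count": null,
"outputs": []
},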
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.6.7",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "a8b7a656ea4dbd4bb94bec8669280ac5",
"data": {
"description": "Contrastive loss in SimCLR ",
"public": true
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/a8b7a656ea4dbd4bb94bec8669280ac5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}