@refraction-ray
Created September 2, 2019 09:03
Complex SVD backprop demo
{
"cells": [
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [],
"source": [
"## just omit lines in this cell\n",
"import sys\n",
"sys.path.append(\"/home/ubuntu/spack/opt/spack/linux-ubuntu18.04-x86_64/gcc-7.4.0/python-3.6.5-63x2grpokc4ax6mgyfilhyjnm5ersc3w/lib/python3.6/site-packages\")\n",
"import os\n",
"os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = \"true\""
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"tf.enable_eager_execution()"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.13.1'"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.__version__"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"@tf.custom_gradient\n",
"def svd(A): # only valid for square matrix A\n",
" S, U, V = tf.svd(A)\n",
" def grad(*dy):\n",
" dS, dU, dV = dy\n",
" dAs = U@tf.diag(tf.cast(dS,dtype=tf.complex128))@tf.linalg.adjoint(V)\n",
" d = 1e-10\n",
" F = (S*S - (S*S)[:, None])\n",
" F = tf.cast(F,dtype=tf.complex128)\n",
" F = 1/(F+d)-tf.diag(tf.diag_part(1/(F+d)))\n",
" J = F*(tf.transpose(tf.conj(U))@dU)\n",
" dAu = U@(J+tf.transpose(tf.conj(J)))@tf.diag(tf.cast(S,dtype=tf.complex128))@tf.linalg.adjoint(V)\n",
" K = F*(tf.transpose(tf.conj(V))@dV)\n",
" dAv = U@tf.diag(tf.cast(S,dtype=tf.complex128))@(K+tf.transpose(tf.conj(K)))@tf.linalg.adjoint(V)\n",
" return dAv + dAu + dAs\n",
" return [S,U,V] , grad"
]
},
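{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `grad` closure above is the standard square-matrix real-SVD adjoint applied verbatim to complex inputs: `dAs`, `dAu` and `dAv` collect the contributions from `dS`, `dU` and `dV`, with $F_{ij} = 1/(s_j^2 - s_i^2)$ off the diagonal and zero on it. The next cell is a minimal forward-pass sanity check only (a sketch; the names `A_check`, `S_c`, `U_c`, `V_c` are illustrative): the custom op should agree with `tf.svd` and reconstruct its input up to numerical precision."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## minimal forward-pass sanity check for the custom op (sketch)\n",
"A_check = tf.constant(np.array([[-1.+1.j, 2.+1.j], [1.-2.j, 3.+0.8j]]))\n",
"S_c, U_c, V_c = svd(A_check)\n",
"S_tf, U_tf, V_tf = tf.svd(A_check)\n",
"recon = U_c @ tf.diag(tf.cast(S_c, dtype=tf.complex128)) @ tf.linalg.adjoint(V_c)\n",
"np.max(np.abs(S_c.numpy() - S_tf.numpy())), np.max(np.abs(recon.numpy() - A_check.numpy()))"
]
},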
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.00942137-0.04090625j, 0.22673931+0.09762721j],\n",
" [-0.08745518+0.14342547j, -0.04197138-0.00069741j]])"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def loss(A):\n",
" S,U,V = svd(A)\n",
" m = tf.conj(U[0,0])*U[0,0]\n",
" return m\n",
"\n",
"def g(f, A):\n",
" with tf.GradientTape() as t:\n",
" t.watch(A)\n",
" y = f(A)\n",
" dy_dA = t.gradient(y, A)\n",
"\n",
" return dy_dA.numpy()\n",
"\n",
"A = tf.constant(np.array([[-1.+1.j,2.+1.j],[1.-2.j,3.+0.8j]]))\n",
"g(loss, A)"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(-0.08745518020834747+0.14342550236116547j)"
]
},
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"da=tf.constant(np.array([[0.+0.j,0.],[1.,0.]]))\n",
"d=1e-6\n",
"\n",
"((loss(A+d*da)-loss(A))/d).numpy()+1j*((loss(A+d*1.j*da)- loss(A))/d).numpy()\n",
"## the result is exactly the same as AD"
]
},
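{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same finite-difference pattern is reused for `loss2` below. As a sketch, the helper in the next cell (the name `num_grad_entry` is illustrative, not part of the demo) packages that pattern: perturb the entry selected by `da` along the real and the imaginary axis and assemble the estimate as (df/dx) + 1j*(df/dy), which is the quantity compared against the tape gradient here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## sketch: the finite-difference pattern above, packaged as a helper\n",
"def num_grad_entry(f, A, da, d=1e-6):\n",
"    # estimate of the gradient entry selected by da, assembled as\n",
"    # (df/dx) + 1j*(df/dy), matching the checks in this notebook\n",
"    re = ((f(A + d*da) - f(A)) / d).numpy()\n",
"    im = ((f(A + d*1.j*da) - f(A)) / d).numpy()\n",
"    return re + 1j*im\n",
"\n",
"num_grad_entry(loss, A, da)  # should reproduce the value above"
]
},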
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.11309833+0.05005293j, -0.03792837+0.11768604j],\n",
" [ 0.11768604+0.03792837j, 0.02540664-0.12104144j]])"
]
},
"execution_count": 107,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def loss2(A):\n",
" S,U,V = svd(A)\n",
" m = tf.real(tf.conj(V[0,0])*U[0,0])\n",
" return m\n",
"g(loss2, A)"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.12817347053162287+0.04221504339152471j)"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"da=tf.constant(np.array([[0.+0.j,0.],[1.,0.]]))\n",
"d=1e-6\n",
"\n",
"((loss2(A+d*da)-loss2(A))/d).numpy()+1j*((loss2(A+d*1.j*da)- loss2(A))/d).numpy()\n",
"## the inconsistence is beyond error bar"
]
},
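{
"cell_type": "markdown",
"metadata": {},
"source": [
"A compact side-by-side comparison, sketched with the `num_grad_entry` helper above: for `loss` the tape gradient entry and the finite-difference estimate agree, while for `loss2` they differ well beyond finite-difference error, suggesting the real-SVD adjoint copied into `svd` is missing a term in the complex case."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## side-by-side comparison of the tape gradient and the finite-difference\n",
"## estimate at entry (1,0) for both losses (sketch, reusing num_grad_entry)\n",
"for f in (loss, loss2):\n",
"    ad = g(f, A)[1, 0]\n",
"    fd = num_grad_entry(f, A, da)\n",
"    print(f.__name__, ad, fd, abs(ad - fd))"
]
},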
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}