Created February 12, 2020 08:11
Spectral Norm.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Spectral Norm.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyM3SmYIAH/9dnq2u36/0rcf",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/MercuriXito/aec10235ab76b3d5b723fc4c3b50e4fc/spectral-norm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rVoqoR3sY2tH",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "x_BMEAoSZaHE",
"colab_type": "code",
"colab": {}
},
"source": [
"def l2norm(W):\n",
"    # scale a vector to unit L2 (Euclidean) norm\n",
"    return W / np.linalg.norm(W)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "O2w2wCMNY5F_",
"colab_type": "code",
"colab": {}
},
"source": [
"def spectral_norm(W, power_iteration=1, u=None, return_u=False):\n",
"    # Divide the 2-D matrix W by a power-iteration estimate of its largest singular value.\n",
"    m, n = W.shape\n",
"    if u is None:\n",
"        # random initial guess for the left singular vector\n",
"        u = np.random.randn(m)\n",
"    for _ in range(power_iteration):\n",
"        v = l2norm(np.matmul(W.T, u))  # right singular vector estimate\n",
"        u = l2norm(np.matmul(W, v))    # left singular vector estimate\n",
"\n",
"    # the largest singular value is approximated by u^T W v\n",
"    norm = u.dot(W.dot(v))\n",
"    print(norm)\n",
"    W = W / norm\n",
"\n",
"    if return_u:\n",
"        # return u as well so a later call can warm-start from it\n",
"        return W, u\n",
"    return W"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rH1IPLJWZ5uS",
"colab_type": "code",
"colab": {}
},
"source": [
"weight = np.random.randn(3, 32, 3, 3)  # conv-layer weight, laid out here as (C_in, C_out, kH, kW)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "YTsvGHVpaqUc",
"colab_type": "text"
},
"source": [
"When applying Spectral Normalization to a convolutional layer's weight of shape ${C_{in} \\times C_{out} \\times W \\times H}$, move the $C_{out}$ axis to the front and then reshape the weight into a 2-D matrix of shape ${C_{out} \\times (C_{in} \\times W \\times H)}$."
]
},
{
"cell_type": "code",
"metadata": {
"id": "gV-JHAkDaJc8",
"colab_type": "code",
"colab": {}
},
"source": [
"m = weight.shape[1]  # C_out\n",
"W = np.moveaxis(weight, 1, 0).reshape(m, -1)  # move C_out to the front, then flatten the rest"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LkcVKzsNaQbD",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 107
},
"outputId": "4a06a33d-a7b6-4f85-d717-6b825d83398b"
},
"source": [
"W, u = spectral_norm(W, 5, return_u=True)  # with return_u=True the function returns (W, u)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"7.831560205837326\n",
"1.0884686577919906\n",
"1.0306490768192758\n",
"1.019976863211998\n",
"1.0143304011361334\n"
],
"name": "stdout"
}
]
},
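{
"cell_type": "markdown",
"metadata": {
"id": "Mq8J9kWNbkt8",
"colab_type": "text"
},
"source": [
"After a single iteration the spectral norm of the normalized matrix is already close to 1 (the printed estimates drop from about 7.83 to about 1.09 and then approach 1), so the method is very efficient.\n",
"\n",
"PyTorch already supports this.\n",
"\n",
"It provides two functions, `torch.nn.utils.spectral_norm()` and `torch.nn.utils.remove_spectral_norm()`, both of which take an `nn.Module` as their argument; see the [source code](https://pytorch.org/docs/stable/_modules/torch/nn/utils/spectral_norm.html#spectral_norm) for details."
]
}
]
}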
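A minimal, self-contained sketch of the same power-iteration estimate, assuming NumPy's `default_rng` for reproducibility; the helper names `l2_normalize` and `estimate_sigma` are hypothetical and not part of the notebook. It runs one iteration per call and warm-starts `u` across calls (which is what the notebook's `u` / `return_u` parameters allow), then compares each estimate with the exact largest singular value from `np.linalg.norm(W, 2)`.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    """Scale a vector to unit Euclidean length (eps guards against division by zero)."""
    return x / max(np.linalg.norm(x), eps)


def estimate_sigma(W, u, n_iters=1):
    """Estimate the largest singular value of W by power iteration.

    n_iters must be >= 1. Returns the estimate and the updated left vector u,
    so the caller can reuse (warm-start) u on the next call.
    """
    v = None
    for _ in range(n_iters):
        v = l2_normalize(W.T @ u)  # right singular vector estimate
        u = l2_normalize(W @ v)    # left singular vector estimate
    sigma = u @ W @ v              # estimate u^T W v
    return sigma, u


rng = np.random.default_rng(0)
W = rng.standard_normal((32, 27))  # e.g. a reshaped C_out x (C_in*kH*kW) weight
true_sigma = np.linalg.norm(W, 2)  # exact largest singular value (computed via SVD)

u = rng.standard_normal(W.shape[0])
errors = []
for step in range(5):
    # one power-iteration step per call, reusing u from the previous call
    sigma, u = estimate_sigma(W, u, n_iters=1)
    errors.append(abs(sigma - true_sigma) / true_sigma)
    print(f"step {step}: sigma_est={sigma:.4f}  true={true_sigma:.4f}")

W_sn = W / sigma
print("spectral norm after normalization:", np.linalg.norm(W_sn, 2))
```

Because `u` and `v` are unit vectors, `u^T W v` can never exceed the true largest singular value, so the estimate approaches it from below and the normalized matrix's spectral norm is at least 1, approaching 1 from above. That appears consistent with the values slightly above 1 printed in the notebook's output.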
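Calling `reshape(C_out, -1)` directly only groups the weights by output channel when $C_{out}$ is the leading axis. Below is a sketch of the axis-aware version, assuming the $C_{in} \times C_{out} \times W \times H$ layout described in the notebook; the variable names are illustrative only. As far as I know, PyTorch's `spectral_norm` handles the same issue with its `dim` argument (0 by default, 1 for `ConvTranspose` modules), permuting that axis to the front before flattening.

```python
import numpy as np

rng = np.random.default_rng(0)
# Conv weight in the assumed (C_in, C_out, kH, kW) layout
weight = rng.standard_normal((3, 32, 3, 3))
c_out_axis = 1
m = weight.shape[c_out_axis]  # C_out = 32

# Move the C_out axis to the front, then flatten the remaining axes,
# so row i holds every weight that feeds output channel i.
W_rows = np.moveaxis(weight, c_out_axis, 0).reshape(m, -1)

# A plain reshape only reinterprets memory order and mixes channels across rows.
W_naive = weight.reshape(m, -1)

# Each row of W_rows is exactly the flattened (C_in, kH, kW) slice of one output channel.
for i in range(m):
    assert np.array_equal(W_rows[i], weight[:, i].ravel())

print(W_rows.shape)  # (32, 27)
print(np.linalg.norm(W_rows, 2), np.linalg.norm(W_naive, 2))
```

The two matrices have the same shape but different entries per row, so their spectral norms generally differ; only the first one corresponds to normalizing per output channel.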