@MercuriXito
Created February 12, 2020 08:11
Spectral Norm.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Spectral Norm.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyM3SmYIAH/9dnq2u36/0rcf",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/MercuriXito/aec10235ab76b3d5b723fc4c3b50e4fc/spectral-norm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rVoqoR3sY2tH",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "x_BMEAoSZaHE",
"colab_type": "code",
"colab": {}
},
"source": [
"def l2norm(W):\n",
"    # scale a vector to unit Euclidean length\n",
"    return W / np.linalg.norm(W)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "O2w2wCMNY5F_",
"colab_type": "code",
"colab": {}
},
"source": [
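"def spectral_norm(W, power_iteration=1, u=None, return_u=False):\n",
"    # W: 2-D matrix; u: optional start vector for the left singular vector\n",
"    m, n = W.shape\n",
"    if u is None:\n",
"        u = np.random.randn(m)\n",
"    for _ in range(power_iteration):\n",
"        v = l2norm(np.matmul(W.T, u))  # right singular vector estimate\n",
"        u = l2norm(np.matmul(W, v))    # left singular vector estimate\n",
"\n",
"        # u^T W v = ||W v|| estimates sigma_max of the current W (never above it)\n",
"        norm = u.T.dot(W.dot(v))\n",
"        print(norm)\n",
"        # W is rescaled every step, so after k steps it has been divided by the\n",
"        # product of the printed values, i.e. the k-step estimate of sigma_max\n",
"        W = W / norm\n",
"\n",
"    if return_u:\n",
"        return W, u\n",
"    return W"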
],
"execution_count": 0,
"outputs": []
},
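{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, one pass of the loop in `spectral_norm` is one step of power iteration. Starting from the random vector $u_0$, step $k$ computes (the index $k$ is notation introduced here, not in the code):\n",
"\n",
"$$v_k = \\frac{W^\\top u_{k-1}}{\\lVert W^\\top u_{k-1} \\rVert_2}, \\qquad u_k = \\frac{W v_k}{\\lVert W v_k \\rVert_2}, \\qquad \\hat{\\sigma}_k = u_k^\\top W v_k = \\lVert W v_k \\rVert_2 \\le \\sigma_{\\max}(W).$$\n",
"\n",
"Dividing $W$ by a positive scalar does not change the directions of $u_k$ and $v_k$. So if every step divides $W$ by the value $n_k$ it just computed, then $n_1 n_2 \\cdots n_k = \\hat{\\sigma}_k$, the $k$-step estimate for the original $W$, and each $n_k$ with $k > 1$ equals $\\hat{\\sigma}_k / \\hat{\\sigma}_{k-1}$."
]
},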
{
"cell_type": "code",
"metadata": {
"id": "rH1IPLJWZ5uS",
"colab_type": "code",
"colab": {}
},
"source": [
"weight = np.random.randn(3, 32, 3, 3)  # conv weight laid out as (C_in, C_out, k_H, k_W): 3 in, 32 out, 3x3 kernel"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "YTsvGHVpaqUc",
"colab_type": "text"
},
"source": [
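"When applying Spectral Normalization to a convolution weight of shape ${C_{in} \\times C_{out} \\times k_H \\times k_W}$, reshape it into a 2-D matrix of shape ${C_{out} \\times (C_{in} \\times k_H \\times k_W)}$. Because $C_{out}$ is not the leading axis in this layout, it has to be moved to the front before reshaping; `reshape` alone does not reorder the data. (PyTorch's `nn.Conv2d` stores its weight as $C_{out} \\times C_{in} \\times k_H \\times k_W$, so there $C_{out}$ is already first.)"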
]
},
{
"cell_type": "code",
"metadata": {
"id": "gV-JHAkDaJc8",
"colab_type": "code",
"colab": {}
},
"source": [
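"m = weight.shape[1]  # C_out\n",
"# move the C_out axis to the front, then flatten (C_in, k_H, k_W) into columns\n",
"W = np.moveaxis(weight, 1, 0).reshape(m, -1)"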
],
"execution_count": 0,
"outputs": []
},
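{
"cell_type": "markdown",
"metadata": {},
"source": [
"Spelled out entry by entry, for a weight stored as $C_{in} \\times C_{out} \\times k_H \\times k_W$ (kernel height $k_H$, width $k_W$; the 0-based indices $i, o, y, x$ are notation introduced here): row $o$ of the matrix holds output channel $o$, and the columns run over $(i, y, x)$ in row-major order, which is what moving the $C_{out}$ axis to the front and then reshaping produces.\n",
"\n",
"$$W_{o,\\,(i k_H + y) k_W + x} = \\mathrm{weight}_{i,\\,o,\\,y,\\,x}, \\qquad W \\in \\mathbb{R}^{C_{out} \\times (C_{in} k_H k_W)}.$$"
]
},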
{
"cell_type": "code",
"metadata": {
"id": "LkcVKzsNaQbD",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 107
},
"outputId": "4a06a33d-a7b6-4f85-d717-6b825d83398b"
},
"source": [
"W, u = spectral_norm(W, 5, return_u=True)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"7.831560205837326\n",
"1.0884686577919906\n",
"1.0306490768192758\n",
"1.019976863211998\n",
"1.0143304011361334\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mq8J9kWNbkt8",
"colab_type": "text"
},
"source": [
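"The first printed value is the one-step estimate of $\\sigma_{\\max}(W)$; every later value is the ratio between successive estimates, because `W` has already been divided by the earlier ones. The ratios approach 1 quickly (the fifth step still changes the estimate by about 1.4%), so a few iterations already give a usable estimate. A single iteration is not exact, though: power iteration never overestimates $\\sigma_{\\max}$, and in this run the one-step value 7.83 is about 14% below the five-step estimate (the product of the five printed values, about 9.09), so dividing by it leaves a spectral norm of at least about 1.16 rather than 1.\n",
"\n",
"PyTorch already supports this through `torch.nn.utils.spectral_norm()` and `torch.nn.utils.remove_spectral_norm()`; both take an `nn.Module` (see the [source code](https://pytorch.org/docs/stable/_modules/torch/nn/utils/spectral_norm.html#spectral_norm)). It stores $u$ in the module and, by default, runs one power iteration per forward pass in training mode, which works because the weights change only slightly between training steps."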
]
}
]
}
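The notebook's run can be cross-checked against an exact value. Below is a minimal NumPy sketch, assuming the same (C_in, C_out, k_H, k_W) layout and a fixed random seed; `power_iteration_sigma` is a hypothetical helper, not part of the notebook or of PyTorch. Unlike `spectral_norm` in the notebook, it does not rescale `W` inside the loop. It compares the estimate after 1, 5 and 50 iterations with `np.linalg.norm(W, 2)` (the largest singular value) and then mimics the warm-start pattern, reusing `u` with one iteration per call. `W` stays fixed here, so the warm start only illustrates the mechanics, not the benefit during training.

```python
import numpy as np

# Minimal sketch, assuming the notebook's (C_in, C_out, k_H, k_W) weight layout.
# power_iteration_sigma is a hypothetical helper, not part of the notebook or of PyTorch.


def power_iteration_sigma(W, n_iter=1, u=None, rng=None):
    """Estimate the largest singular value of the 2-D matrix W by power iteration.

    Returns (sigma_hat, u); pass u back in to warm-start the next call.
    """
    if n_iter < 1:
        raise ValueError("n_iter must be at least 1")
    if u is None:
        rng = np.random.default_rng() if rng is None else rng
        u = rng.standard_normal(W.shape[0])
    for _ in range(n_iter):
        v = W.T @ u
        v = v / np.linalg.norm(v)  # unit right singular vector estimate
        u = W @ v
        u = u / np.linalg.norm(u)  # unit left singular vector estimate
    sigma_hat = u @ W @ v  # equals ||W v||, so it never exceeds sigma_max
    return sigma_hat, u


rng = np.random.default_rng(0)
weight = rng.standard_normal((3, 32, 3, 3))  # (C_in, C_out, k_H, k_W)
c_out = weight.shape[1]
W = np.moveaxis(weight, 1, 0).reshape(c_out, -1)  # (C_out, C_in * k_H * k_W)
sigma_exact = np.linalg.norm(W, 2)  # exact largest singular value (via SVD)

# Cold start: the same initial vector, different numbers of iterations.
u0 = rng.standard_normal(c_out)
estimates = {}
for k in (1, 5, 50):
    estimates[k], _ = power_iteration_sigma(W, n_iter=k, u=u0)
    print(f"{k:>2} iterations: {estimates[k]:.4f} (ratio to exact {estimates[k] / sigma_exact:.4f})")

# Warm start: one iteration per call, carrying u over, as done across training steps.
u = None
for step in range(50):
    est_warm, u = power_iteration_sigma(W, n_iter=1, u=u, rng=rng)
print(f"warm start after 50 calls: ratio to exact {est_warm / sigma_exact:.4f}")

# For comparison: reshaping without moving C_out to the front gives a different matrix.
W_naive = weight.reshape(c_out, -1)
print(f"sigma_max with axis moved: {sigma_exact:.4f}, plain reshape: {np.linalg.norm(W_naive, 2):.4f}")
```

Each estimate equals ||W v|| for a unit vector v, so it can only approach the exact value from below, and it does not decrease as iterations are added; how fast it converges depends on the gap between the two largest singular values.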