how-gradients-are-accumulated-in-real.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "how-gradients-are-accumulated-in-real.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyP6EDCKWcBthpUpOUpB1EC9",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/crcrpar/e3a01c2a2624be18d5ff1576c50dd632/how-gradients-are-accumulated-in-real.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "23yq9ekpFo1p"
},
"source": [
"import torch\n",
"import torch.nn as nn"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zwem8NiTFsKx"
},
"source": [
"m = nn.Linear(16, 1)"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "aP1RncLkFvss"
},
"source": [
"def stub(m: nn.Module):\n",
"\n",
" for _ in range(5):\n",
" x = torch.rand(32, 16, dtype=torch.float32)\n",
" y = torch.rand(1, 1, dtype=torch.float32)\n",
"\n",
" loss = torch.norm(y - m(x))\n",
" loss.backward()\n",
"\n",
" print(m.weight.grad)"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Af_RCQSGGAnY",
"outputId": "2b497550-75fa-4363-979a-c9be32a74ca7"
},
"source": [
"stub(m)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor([[-3.0965, -2.8518, -2.9717, -2.7008, -2.6433, -2.7176, -2.9091, -2.8171,\n",
" -3.1226, -2.8107, -2.4387, -2.4836, -2.5720, -3.1647, -2.0334, -2.7457]])\n",
"tensor([[-5.8239, -5.4552, -5.8949, -5.2432, -5.4305, -5.3250, -5.7898, -5.8785,\n",
" -5.6517, -5.7058, -5.1769, -4.9697, -4.8814, -5.9802, -4.7384, -5.7139]])\n",
"tensor([[-8.3533, -7.9277, -8.8049, -8.4401, -8.1791, -8.0246, -8.7390, -8.7580,\n",
" -9.0928, -8.2384, -7.7129, -7.7327, -7.6828, -8.6108, -7.1655, -8.2564]])\n",
"tensor([[-10.9206, -10.6197, -11.2812, -11.2026, -10.7852, -10.7390, -11.2756,\n",
" -11.8269, -12.1937, -10.9059, -10.3348, -10.1517, -10.1838, -11.1223,\n",
" -9.4941, -10.4348]])\n",
"tensor([[-13.6102, -13.2164, -14.4526, -13.9059, -13.5810, -13.3799, -14.1594,\n",
" -14.4318, -14.7769, -13.8465, -12.2982, -12.6835, -12.2972, -12.9532,\n",
" -11.4039, -13.0911]])\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "exBYWKXaGBNA"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
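,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gradients printed above keep growing because `loss.backward()` accumulates into `m.weight.grad` instead of overwriting it. Below is a minimal sketch, assuming the `m` defined above, of the usual way to stop the accumulation: reset the gradients (with `zero_grad()`, or by setting `.grad` to `None`) before the next backward pass. This cell is an illustration rather than part of the original run."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Minimal sketch (assumes the m, torch and nn defined above): reset the\n",
"# accumulated gradient before the next backward pass. Depending on the\n",
"# PyTorch version, zero_grad() either fills .grad with zeros or sets it\n",
"# to None (set_to_none=True is the default in newer releases).\n",
"m.zero_grad()\n",
"print(m.weight.grad)  # zeros, or None when set_to_none is the default\n",
"\n",
"x = torch.rand(32, 16, dtype=torch.float32)\n",
"y = torch.rand(1, 1, dtype=torch.float32)\n",
"torch.norm(y - m(x)).backward()\n",
"print(m.weight.grad)  # gradient of this single batch only, no carry-over"
],
"execution_count": null,
"outputs": []
}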
]
}