Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import gc\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"\n",
"from sklearn.datasets import make_classification"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1024,)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Synthetic binary task: 1024 samples x 256 features, 128 of them informative,\n",
"# no redundant/repeated columns, 1% of labels randomly reassigned, fixed seed.\n",
"X, y = make_classification(\n",
"    random_state=42,\n",
"    n_samples=1024,\n",
"    n_features=256,\n",
"    n_informative=128,\n",
"    n_redundant=0,\n",
"    n_repeated=0,\n",
"    n_classes=2,\n",
"    n_clusters_per_class=2,\n",
"    class_sep=1.0,\n",
"    hypercube=True,\n",
"    flip_y=0.01,\n",
"    shuffle=True,\n",
")\n",
"y.shape"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def get_model(swish_module, width=2 ** 19, in_features=256):\n",
"    \"\"\"Build the one-hidden-layer binary classifier used to compare Swish variants.\n",
"\n",
"    Args:\n",
"        swish_module: an ``nn.Module`` subclass (not an instance); it is\n",
"            instantiated with no arguments and used as the activation layer.\n",
"        width: hidden-layer size. The default (2 ** 19) deliberately makes the\n",
"            model very large.\n",
"        in_features: input dimensionality; must match the number of columns of ``X``.\n",
"\n",
"    Returns:\n",
"        ``nn.Sequential``: Linear -> activation -> BatchNorm1d -> Linear(width, 1),\n",
"        i.e. one logit per sample with output shape (batch, 1).\n",
"    \"\"\"\n",
"    return nn.Sequential(\n",
"        nn.Linear(in_features, width),\n",
"        swish_module(),\n",
"        nn.BatchNorm1d(width),\n",
"        nn.Linear(width, 1),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# get_model ends in a plain Linear layer (raw logits), so use the loss that\n",
"# applies the sigmoid internally.\n",
"criterion = nn.BCEWithLogitsLoss()\n",
"batch_size = 2 ** 7  # 128 samples per batch"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def print_parameter_count(model):\n",
"    \"\"\"Print the total and the trainable parameter counts of ``model``.\n",
"\n",
"    Uses the built-in ``sum`` instead of ``np.sum``: ``np.sum`` of an empty list\n",
"    returns the float ``0.0``, which the ``{:,d}`` format rejects with a\n",
"    ValueError for parameter-free modules such as ``nn.ReLU()``.\n",
"    \"\"\"\n",
"    params = list(model.parameters())\n",
"    total = sum(p.numel() for p in params)\n",
"    trainable = sum(p.numel() for p in params if p.requires_grad)\n",
"    print(\"# of parameters: {:,d}\".format(total))\n",
"    print(\"# of trainable parameters: {:,d}\".format(trainable))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plain Swish Version"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class PlainSwish(nn.Module):\n",
"    \"\"\"Swish activation, x * sigmoid(x), built from stock autograd ops.\n",
"\n",
"    Autograd decides which intermediates to keep for the backward pass; this is\n",
"    the baseline whose memory use is compared against the custom Function.\n",
"    \"\"\"\n",
"\n",
"    def forward(self, input_tensor):\n",
"        gate = torch.sigmoid(input_tensor)\n",
"        return input_tensor * gate"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# of parameters: 136,314,881\n",
"# of trainable parameters: 136,314,881\n"
]
},
{
"data": {
"text/plain": [
"524.0009765625"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Plain-Swish model on the GPU; the last line is the allocated GPU memory in MiB.\n",
"model = get_model(PlainSwish).to(\"cuda\")\n",
"print_parameter_count(model)\n",
"optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)\n",
"optimizer.zero_grad()\n",
"torch.cuda.memory_allocated() / 2 ** 20"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data: 524.12646484375\n",
"forw: 1552.126953125\n",
"loss: 1552.12744140625\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n"
]
}
],
"source": [
"# One pass over the data, printing allocated GPU memory (MiB) after each stage.\n",
"# The loop bound is len(X) rather than a hard-coded 1024 so it follows the dataset size.\n",
"for i in range(0, len(X), batch_size):\n",
"    Xt = torch.tensor(X[i:i+batch_size], dtype=torch.float).cuda()\n",
"    yt = torch.tensor(y[i:i+batch_size], dtype=torch.float).cuda()\n",
"    print(\"data:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    pred = model(Xt)[:, 0]  # (batch, 1) logits -> (batch,)\n",
"    print(\"forw:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    loss = criterion(pred, yt)\n",
"    print(\"loss:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    loss.backward()\n",
"    print(\"back:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()\n",
"    print(\"step:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    print(\"=\" * 20)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Drop every reference to GPU tensors, then confirm nothing is still allocated.\n",
"# Reported in MiB, the same unit as the other memory readings in this notebook.\n",
"del optimizer, model, Xt, yt, loss, pred\n",
"gc.collect()\n",
"torch.cuda.memory_allocated() / 1024 ** 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Swish Version\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class Swish(torch.autograd.Function):\n",
"    \"\"\"Swish, x * sigmoid(x), with a hand-written backward pass.\n",
"\n",
"    Only the input is saved for backward; sigmoid(x) is recomputed there\n",
"    instead of being stored between the forward and backward passes.\n",
"    \"\"\"\n",
"\n",
"    @staticmethod\n",
"    def forward(ctx, i):\n",
"        result = i * torch.sigmoid(i)\n",
"        ctx.save_for_backward(i)\n",
"        return result\n",
"\n",
"    @staticmethod\n",
"    def backward(ctx, grad_output):\n",
"        # saved_tensors replaces the deprecated ctx.saved_variables.\n",
"        (i,) = ctx.saved_tensors\n",
"        sigmoid_i = torch.sigmoid(i)\n",
"        # d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x)) = s(x) * (1 + x * (1 - s(x)))\n",
"        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))\n",
"\n",
"\n",
"class CustomSwish(nn.Module):\n",
"    \"\"\"nn.Module wrapper so the Swish Function can be used inside nn.Sequential.\"\"\"\n",
"\n",
"    def forward(self, input_tensor):\n",
"        return Swish.apply(input_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"536577.0"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Custom-Swish model on the GPU; report allocated memory in MiB, the same unit\n",
"# used for the plain-Swish model so the two readings are directly comparable.\n",
"model = get_model(CustomSwish).cuda()\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
"optimizer.zero_grad()\n",
"torch.cuda.memory_allocated() / 1024 ** 2"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data: 524.12646484375\n",
"forw: 1296.126953125\n",
"loss: 1296.12744140625\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ceshine/miniconda3/envs/deep/lib/python3.7/site-packages/ipykernel_launcher.py:10: DeprecationWarning: 'saved_variables' is deprecated; use 'saved_tensors'\n",
" # Remove the CWD from sys.path while we load stuff.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n"
]
}
],
"source": [
"# One pass over the data, printing allocated GPU memory (MiB) after each stage.\n",
"# The loop bound is len(X) rather than a hard-coded 1024 so it follows the dataset size.\n",
"for i in range(0, len(X), batch_size):\n",
"    Xt = torch.tensor(X[i:i+batch_size], dtype=torch.float).cuda()\n",
"    yt = torch.tensor(y[i:i+batch_size], dtype=torch.float).cuda()\n",
"    print(\"data:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    pred = model(Xt)[:, 0]  # (batch, 1) logits -> (batch,)\n",
"    print(\"forw:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    loss = criterion(pred, yt)\n",
"    print(\"loss:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    loss.backward()\n",
"    print(\"back:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()\n",
"    print(\"step:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    print(\"=\" * 20)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Drop every reference to GPU tensors, then confirm nothing is still allocated.\n",
"# Reported in MiB, the same unit as the other memory readings in this notebook.\n",
"del optimizer, model, Xt, yt, loss, pred\n",
"gc.collect()\n",
"torch.cuda.memory_allocated() / 1024 ** 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment