Skip to content

Instantly share code, notes, and snippets.

@geffy
Created August 24, 2017 07:08
Show Gist options
  • Save geffy/03d5f9a2cefc7b8feff993435fd3b139 to your computer and use it in GitHub Desktop.
Save geffy/03d5f9a2cefc7b8feff993435fd3b139 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.sparse import rand as sprand\n",
"import torch\n",
"from torch import nn\n",
"from torch.utils.data import TensorDataset, DataLoader\n",
"from torch.autograd import Variable\n",
"import pickle\n",
"import tqdm\n",
"from sklearn.metrics import mean_squared_error"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Load the preprocessed split. The context manager closes the file handle\n",
"# that a bare open() call would leave dangling.\n",
"# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.\n",
"with open('../data/processed/ml-1M.pickled', 'rb') as fh:\n",
"    data = pickle.load(fh)\n",
"\n",
"# np.int64 produces int64 (LongTensor) indices on every platform; np.long is\n",
"# absent from NumPy 1.24-1.26 and platform-dependent otherwise.\n",
"tr_ds = TensorDataset(\n",
"    torch.from_numpy(data['X_tr'].astype(np.int64)),\n",
"    torch.from_numpy(data['y_tr'].astype(np.float32)),\n",
")\n",
"tr_iter = DataLoader(tr_ds, batch_size=1024, shuffle=True)\n",
"\n",
"te_ds = TensorDataset(\n",
"    torch.from_numpy(data['X_te'].astype(np.int64)),\n",
"    torch.from_numpy(data['y_te'].astype(np.float32)),\n",
")\n",
"te_iter = DataLoader(te_ds, batch_size=1024, shuffle=False)\n",
"\n",
"# Column-wise maxima of the training pairs are the largest ids seen; adding 1\n",
"# turns them into embedding-table sizes.\n",
"# NOTE(review): assumes column 0 is the user id and column 1 the item id, and\n",
"# that the test split holds no larger ids -- confirm against the preprocessing.\n",
"n_users, n_items = np.max(data['X_tr'], axis=0)\n",
"n_users = int(n_users) + 1\n",
"n_items = int(n_items) + 1"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"class ScaledEmbedding(nn.Embedding):\n",
"    \"\"\"Embedding layer with a custom zero-mean normal initialization.\"\"\"\n",
"\n",
"    def reset_parameters(self):\n",
"        \"\"\"Draw the weights from a normal with std 0.1 / embedding_dim.\n",
"\n",
"        The padding row, when one is configured, is reset to zeros afterwards.\n",
"        \"\"\"\n",
"        self.weight.data.normal_(0, 0.1 / self.embedding_dim)\n",
"        if self.padding_idx is not None:\n",
"            self.weight.data[self.padding_idx].fill_(0)\n",
" \n",
"\n",
"class MatrixFactorization(torch.nn.Module):\n",
"    \"\"\"Score a (user, item) pair as the dot product of their factor vectors.\"\"\"\n",
"\n",
"    def __init__(self, n_users, n_items, n_factors=20):\n",
"        \"\"\"Create one embedding table for users and one for items.\n",
"\n",
"        Args:\n",
"            n_users: number of rows in the user table.\n",
"            n_items: number of rows in the item table.\n",
"            n_factors: width of every factor vector.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.user_factors = ScaledEmbedding(n_users, n_factors)\n",
"        self.item_factors = ScaledEmbedding(n_items, n_factors)\n",
"\n",
"    def forward(self, user, item):\n",
"        \"\"\"Return the element-wise product of both lookups, summed over dim 1.\n",
"\n",
"        NOTE(review): assumes `user` and `item` are 1-D index tensors, so the\n",
"        result holds one score per pair -- confirm against the callers.\n",
"        \"\"\"\n",
"        user_vecs = self.user_factors(user)\n",
"        item_vecs = self.item_factors(item)\n",
"        return (user_vecs * item_vecs).sum(dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
"\n",
"# Factorization model with 60 factors per user and per item.\n",
"model = MatrixFactorization(n_users, n_items, n_factors=60)\n",
"mse = torch.nn.MSELoss()\n",
"optimizer = torch.optim.Adam(\n",
"    model.parameters(),\n",
"    lr=1e-3,\n",
"    weight_decay=1e-9,\n",
")\n",
"# Multiply the learning rate by `factor` once the metric handed to\n",
"# scheduler.step() stops improving by the absolute `threshold`.\n",
"scheduler = ReduceLROnPlateau(\n",
"    optimizer,\n",
"    factor=0.1,\n",
"    patience=3,\n",
"    threshold=3e-4,\n",
"    threshold_mode='abs',\n",
"    verbose=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Train"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def train():\n",
"    \"\"\"Run one optimization pass over `tr_iter` and return the per-batch losses.\n",
"\n",
"    Uses the notebook-level `model`, `mse` and `optimizer`. Prints the mean\n",
"    batch MSE and its square root once the pass is finished.\n",
"\n",
"    Returns:\n",
"        list of float: the MSE of every batch, in iteration order.\n",
"    \"\"\"\n",
"    loss_hist = []\n",
"    for batch_X, batch_y in tqdm.tqdm(tr_iter):\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Column 0 and column 1 of each row feed the model's user and item\n",
"        # arguments; calling the module (not .forward) keeps hooks working.\n",
"        prediction = model(batch_X[:, 0], batch_X[:, 1])\n",
"        loss = mse(prediction, batch_y.float())\n",
"\n",
"        # The loss is a 0-dim tensor: .item() extracts the Python float, whereas\n",
"        # indexing its NumPy view with [0] raises IndexError on PyTorch >= 0.4.\n",
"        loss_hist.append(loss.item())\n",
"\n",
"        # Backpropagate, then update the parameters.\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    mean_loss = np.mean(loss_hist)\n",
"    print('mse: {}| rmse: {}'.format(mean_loss, np.sqrt(mean_loss)))\n",
"    return loss_hist"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Test"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def test():\n",
"    \"\"\"Predict every batch of `te_iter` with `model`.\n",
"\n",
"    The targets yielded by the loader are ignored; only predictions are kept.\n",
"\n",
"    Returns:\n",
"        numpy.ndarray: the per-batch predictions concatenated in loader order.\n",
"    \"\"\"\n",
"    pred_hist = []\n",
"    # Inference needs no gradients; no_grad() skips building the autograd graph.\n",
"    with torch.no_grad():\n",
"        for batch_X, _ in tqdm.tqdm(te_iter):\n",
"            prediction = model(batch_X[:, 0], batch_X[:, 1])\n",
"            pred_hist.append(prediction.numpy())\n",
"    return np.concatenate(pred_hist)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 74.39it/s]\n",
" 14%|█▍ | 14/98 [00:00<00:00, 135.49it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 4.653479099273682| rmse: 2.1571924686431885\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 177.92it/s]\n",
" 1%| | 5/880 [00:00<00:18, 47.60it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 1.01173530424\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:12<00:00, 72.27it/s]\n",
" 22%|██▏ | 22/98 [00:00<00:00, 219.62it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.9061034917831421| rmse: 0.9518947005271912\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 165.33it/s]\n",
" 0%| | 2/880 [00:00<00:51, 17.15it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 0.93067693311\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:12<00:00, 70.76it/s]\n",
" 21%|██▏ | 21/98 [00:00<00:00, 161.11it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.851354718208313| rmse: 0.9226888418197632\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 187.60it/s]\n",
" 1%| | 5/880 [00:00<00:23, 38.01it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2 0.9221900053\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 77.60it/s]\n",
" 14%|█▍ | 14/98 [00:00<00:00, 136.96it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.8407105207443237| rmse: 0.9169026613235474\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 175.47it/s]\n",
" 1%| | 5/880 [00:00<00:18, 47.06it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"3 0.917020341644\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 76.72it/s]\n",
" 13%|█▎ | 13/98 [00:00<00:00, 128.18it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.8308556079864502| rmse: 0.9115127921104431\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 173.40it/s]\n",
" 1%| | 5/880 [00:00<00:18, 47.75it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"4 0.911672158106\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 78.59it/s]\n",
" 21%|██▏ | 21/98 [00:00<00:00, 202.82it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.8189589977264404| rmse: 0.9049635529518127\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 191.10it/s]\n",
" 1%| | 5/880 [00:00<00:18, 46.53it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"5 0.907034108872\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 77.61it/s]\n",
" 17%|█▋ | 17/98 [00:00<00:00, 141.66it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.8076307773590088| rmse: 0.8986827731132507\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 173.19it/s]\n",
" 1%| | 5/880 [00:00<00:18, 48.50it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"6 0.902936388422\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 76.86it/s]\n",
" 19%|█▉ | 19/98 [00:00<00:00, 185.28it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.7959813475608826| rmse: 0.8921778798103333\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 172.34it/s]\n",
" 1%| | 5/880 [00:00<00:18, 47.96it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"7 0.898566136779\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 74.25it/s]\n",
" 13%|█▎ | 13/98 [00:00<00:00, 128.63it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.7819766402244568| rmse: 0.8842944502830505\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 169.65it/s]\n",
" 1%| | 5/880 [00:00<00:18, 47.47it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"8 0.891737408583\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 74.63it/s]\n",
" 20%|██ | 20/98 [00:00<00:00, 154.51it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.7655534148216248| rmse: 0.8749591112136841\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 187.24it/s]\n",
" 0%| | 3/880 [00:00<00:31, 27.48it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9 0.884484992531\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:12<00:00, 78.19it/s]\n",
" 14%|█▍ | 14/98 [00:00<00:00, 138.73it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.7477543354034424| rmse: 0.8647279143333435\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 174.74it/s]\n",
" 1%| | 5/880 [00:00<00:18, 46.39it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 0.878235862071\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 77.35it/s]\n",
" 13%|█▎ | 13/98 [00:00<00:00, 127.24it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.7289107441902161| rmse: 0.853762686252594\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 173.27it/s]\n",
" 1%| | 5/880 [00:00<00:18, 48.17it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"11 0.872880039698\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 76.73it/s]\n",
" 19%|█▉ | 19/98 [00:00<00:00, 189.53it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.708349883556366| rmse: 0.8416352272033691\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 184.58it/s]\n",
" 1%| | 5/880 [00:00<00:18, 47.70it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"12 0.866932788908\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 75.45it/s]\n",
" 20%|██ | 20/98 [00:00<00:00, 191.45it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.6866269707679749| rmse: 0.8286295533180237\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 184.78it/s]\n",
" 0%| | 4/880 [00:00<00:22, 38.22it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"13 0.861228102471\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:12<00:00, 68.93it/s]\n",
" 12%|█▏ | 12/98 [00:00<00:00, 112.82it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.6639280319213867| rmse: 0.8148177862167358\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 133.63it/s]\n",
" 0%| | 3/880 [00:00<00:32, 26.69it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"14 0.857208540688\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:12<00:00, 68.46it/s]\n",
" 13%|█▎ | 13/98 [00:00<00:00, 128.50it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.6401769518852234| rmse: 0.8001105785369873\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 175.45it/s]\n",
" 1%| | 5/880 [00:00<00:19, 45.31it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"15 0.854500042774\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 76.96it/s]\n",
" 11%|█ | 11/98 [00:00<00:00, 109.56it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.6150626540184021| rmse: 0.7842593193054199\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 170.92it/s]\n",
" 1%| | 5/880 [00:00<00:18, 47.23it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"16 0.852726459753\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 74.62it/s]\n",
" 20%|██ | 20/98 [00:00<00:00, 196.20it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.5880409479141235| rmse: 0.7668382525444031\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 185.30it/s]\n",
" 1%| | 5/880 [00:00<00:17, 48.76it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"17 0.85159162539\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 77.37it/s]\n",
" 19%|█▉ | 19/98 [00:00<00:00, 189.47it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.5601016879081726| rmse: 0.7483994364738464\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 183.51it/s]\n",
" 1%| | 5/880 [00:00<00:23, 37.87it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"18 0.851854084252\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 76.52it/s]\n",
" 14%|█▍ | 14/98 [00:00<00:00, 131.52it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.5313001871109009| rmse: 0.7289034128189087\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 171.02it/s]\n",
" 1%| | 5/880 [00:00<00:18, 48.28it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"19 0.853770273338\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 880/880 [00:11<00:00, 73.87it/s]\n",
" 13%|█▎ | 13/98 [00:00<00:00, 128.89it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mse: 0.5032721161842346| rmse: 0.7094167470932007\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 98/98 [00:00<00:00, 171.69it/s]\n",
" 1%| | 5/880 [00:00<00:17, 48.91it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"20 0.857647674373\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 25%|██▍ | 217/880 [00:02<00:08, 73.79it/s]"
]
}
],
"source": [
"val_hist = []\n",
"for i in range(50):\n",
"    # One pass over the training data, then predictions for the test loader.\n",
"    loss_hist = train()\n",
"    pred = test()\n",
"\n",
"    # RMSE against the test targets; the scheduler watches it for plateaus.\n",
"    test_mse = mean_squared_error(data['y_te'], pred)\n",
"    val_err = np.sqrt(test_mse)\n",
"    scheduler.step(val_err)\n",
"\n",
"    val_hist.append(val_err)\n",
"    print(i, val_err)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Reference score for comparison -- SVD-20: 0.8623 (presumably the test RMSE of a 20-factor SVD; TODO confirm)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment