@srishilesh
Created March 22, 2020 14:12
Boston_house_Federated_learning
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Boston_house_Federated_learning",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/srishilesh/673469c0814cc54902c708b755d567a4/boston_house_federated_learning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mzUVE2unkeFq",
"colab_type": "text"
},
"source": [
"**INSTALLING PYSYFT**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "s-CaEV43c6WZ",
"colab_type": "code",
"outputId": "7e64fa04-81f5-4307-d9ec-cf0d2ffb6737",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 564
}
},
"source": [
"# This notebook uses the PySyft 0.2.x API (sy.TorchHook, sy.VirtualWorker)\n",
"!pip install syft==0.2.3"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: syft in /usr/local/lib/python3.6/dist-packages (0.2.3)\n",
"Requirement already satisfied: torchvision~=0.5.0 in /usr/local/lib/python3.6/dist-packages (from syft) (0.5.0)\n",
"Requirement already satisfied: phe~=1.4.0 in /usr/local/lib/python3.6/dist-packages (from syft) (1.4.0)\n",
"Requirement already satisfied: msgpack~=1.0.0 in /usr/local/lib/python3.6/dist-packages (from syft) (1.0.0)\n",
"Requirement already satisfied: flask-socketio~=4.2.1 in /usr/local/lib/python3.6/dist-packages (from syft) (4.2.1)\n",
"Requirement already satisfied: lz4~=3.0.2 in /usr/local/lib/python3.6/dist-packages (from syft) (3.0.2)\n",
"Requirement already satisfied: numpy~=1.18.1 in /usr/local/lib/python3.6/dist-packages (from syft) (1.18.1)\n",
"Requirement already satisfied: zstd~=1.4.4.0 in /usr/local/lib/python3.6/dist-packages (from syft) (1.4.4.0)\n",
"Requirement already satisfied: scipy~=1.4.1 in /usr/local/lib/python3.6/dist-packages (from syft) (1.4.1)\n",
"Requirement already satisfied: torch~=1.4.0 in /usr/local/lib/python3.6/dist-packages (from syft) (1.4.0)\n",
"Requirement already satisfied: tblib~=1.6.0 in /usr/local/lib/python3.6/dist-packages (from syft) (1.6.0)\n",
"Requirement already satisfied: websocket-client~=0.57.0 in /usr/local/lib/python3.6/dist-packages (from syft) (0.57.0)\n",
"Requirement already satisfied: websockets~=8.1.0 in /usr/local/lib/python3.6/dist-packages (from syft) (8.1)\n",
"Requirement already satisfied: Pillow~=6.2.2 in /usr/local/lib/python3.6/dist-packages (from syft) (6.2.2)\n",
"Requirement already satisfied: requests~=2.22.0 in /usr/local/lib/python3.6/dist-packages (from syft) (2.22.0)\n",
"Requirement already satisfied: syft-proto~=0.2.1.a1.post2 in /usr/local/lib/python3.6/dist-packages (from syft) (0.2.1a1.post2)\n",
"Requirement already satisfied: Flask~=1.1.1 in /usr/local/lib/python3.6/dist-packages (from syft) (1.1.1)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision~=0.5.0->syft) (1.12.0)\n",
"Requirement already satisfied: python-socketio>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from flask-socketio~=4.2.1->syft) (4.4.0)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests~=2.22.0->syft) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests~=2.22.0->syft) (1.24.3)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests~=2.22.0->syft) (2.8)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests~=2.22.0->syft) (2019.11.28)\n",
"Requirement already satisfied: protobuf>=3.11.1 in /usr/local/lib/python3.6/dist-packages (from syft-proto~=0.2.1.a1.post2->syft) (3.11.3)\n",
"Requirement already satisfied: Werkzeug>=0.15 in /usr/local/lib/python3.6/dist-packages (from Flask~=1.1.1->syft) (1.0.0)\n",
"Requirement already satisfied: itsdangerous>=0.24 in /usr/local/lib/python3.6/dist-packages (from Flask~=1.1.1->syft) (1.1.0)\n",
"Requirement already satisfied: click>=5.1 in /usr/local/lib/python3.6/dist-packages (from Flask~=1.1.1->syft) (7.0)\n",
"Requirement already satisfied: Jinja2>=2.10.1 in /usr/local/lib/python3.6/dist-packages (from Flask~=1.1.1->syft) (2.11.1)\n",
"Requirement already satisfied: python-engineio>=3.9.0 in /usr/local/lib/python3.6/dist-packages (from python-socketio>=4.3.0->flask-socketio~=4.2.1->syft) (3.11.2)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.11.1->syft-proto~=0.2.1.a1.post2->syft) (45.2.0)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from Jinja2>=2.10.1->Flask~=1.1.1->syft) (1.1.1)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3zpjJFxZkq2F",
"colab_type": "text"
},
"source": [
"**IMPORTING PACKAGES**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "FHKVCwEJbpGL",
"colab_type": "code",
"colab": {}
},
"source": [
"import pickle\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import TensorDataset, DataLoader\n",
"import time\n",
"import copy\n",
"import numpy as np\n",
"import syft as sy\n",
"from syft.frameworks.torch.fl import utils\n",
"from syft.workers.websocket_client import WebsocketClientWorker"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "SJRQXEAGk3qM",
"colab_type": "text"
},
"source": [
"**INITIALIZING THE HYPERPARAMETERS**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "LpyLPcFicm2w",
"colab_type": "code",
"outputId": "3aa97652-7f4a-4d16-f367-2d6f11f4cfd8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"class Parser:\n",
"    def __init__(self):\n",
"        self.epochs = 100\n",
"        self.lr = 0.001\n",
"        self.test_batch_size = 8\n",
"        self.batch_size = 8\n",
"        self.log_interval = 10\n",
"        self.seed = 1\n",
"\n",
"args = Parser()\n",
"torch.manual_seed(args.seed)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f2a2315b190>"
]
},
"metadata": {
"tags": []
},
"execution_count": 59
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PB_W4udlk9KZ",
"colab_type": "text"
},
"source": [
"**DATA PREPROCESSING**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "IvAy-a6oiHNJ",
"colab_type": "code",
"colab": {}
},
"source": [
"# The pickle must contain ((x_train, y_train), (x_test, y_test)) as NumPy arrays\n",
"with open('./boston_housing.pickle','rb') as f:\n",
"    ((x, y), (x_test, y_test)) = pickle.load(f)\n",
"\n",
"x = torch.from_numpy(x).float()\n",
"y = torch.from_numpy(y).float()\n",
"x_test = torch.from_numpy(x_test).float()\n",
"y_test = torch.from_numpy(y_test).float()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0waqckO-iMef",
"colab_type": "code",
"colab": {}
},
"source": [
"# Standardize each feature using statistics from the training set only\n",
"mean = x.mean(0, keepdim=True)\n",
"dev = x.std(0, keepdim=True)\n",
"# Column 3 (CHAS) is a 0/1 indicator, so leave it unscaled\n",
"mean[:, 3] = 0.\n",
"dev[:, 3] = 1.\n",
"x = (x - mean) / dev\n",
"x_test = (x_test - mean) / dev\n",
"train = TensorDataset(x, y)\n",
"test = TensorDataset(x_test, y_test)\n",
"train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True)\n",
"# No need to shuffle the evaluation data\n",
"test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=False)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "L62e-zZUlFFT",
"colab_type": "text"
},
"source": [
"**NEURAL NETWORK ARCHITECTURE**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "QbdQ-SIjjOLY",
"colab_type": "code",
"colab": {}
},
"source": [
"class Net(nn.Module):\n",
"    def __init__(self):\n",
"        super(Net, self).__init__()\n",
"        self.fc1 = nn.Linear(13, 32)\n",
"        self.fc2 = nn.Linear(32, 24)\n",
"        # fc4 sits between fc2 and fc3; fc3 is the single-output regression layer\n",
"        self.fc4 = nn.Linear(24, 16)\n",
"        self.fc3 = nn.Linear(16, 1)\n",
"\n",
"    def forward(self, x):\n",
"        x = x.view(-1, 13)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        x = F.relu(self.fc4(x))\n",
"        x = self.fc3(x)\n",
"        return x"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "kGSxtU7Klv2T",
"colab_type": "text"
},
"source": [
"**CONNECTING TO WORKERS FOR TRAINING**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "d02WcKOnjZP4",
"colab_type": "code",
"outputId": "fef2e23e-651b-44dc-c9bd-9233cc142a58",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"hook = sy.TorchHook(torch)\n",
"# VirtualWorkers simulate bob and alice inside this Python process\n",
"bob_worker = sy.VirtualWorker(hook, id=\"bob\")\n",
"alice_worker = sy.VirtualWorker(hook, id=\"alice\")\n",
"# Alternative: connect to separately running websocket server workers\n",
"# kwargs_websocket = {\"host\": \"localhost\", \"hook\": hook}\n",
"# alice = WebsocketClientWorker(id='alice', port=8779, **kwargs_websocket)\n",
"# bob = WebsocketClientWorker(id='bob', port=8778, **kwargs_websocket)\n",
"compute_nodes = [bob_worker, alice_worker]"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:root:Torch was already hooked... skipping hooking process\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Kbjc4bIlknW",
"colab_type": "text"
},
"source": [
"**SENDING THE DATA TO THE REMOTE WORKERS**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "GKBl3e9Djb1t",
"colab_type": "code",
"colab": {}
},
"source": [
"# Deal the training batches round-robin: even batches to bob, odd batches to alice\n",
"remote_dataset = (list(), list())\n",
"\n",
"for batch_idx, (data, target) in enumerate(train_loader):\n",
"    worker_idx = batch_idx % len(compute_nodes)\n",
"    data = data.send(compute_nodes[worker_idx])\n",
"    target = target.send(compute_nodes[worker_idx])\n",
"    remote_dataset[worker_idx].append((data, target))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lP1eA8Tqjgs6",
"colab_type": "code",
"colab": {}
},
"source": [
"bobs_model = Net()\n",
"alices_model = Net()\n",
"bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)\n",
"alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "TJPN12FEEa9B",
"colab_type": "text"
},
"source": [
"**WEIGHTS BEFORE TRAINING**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "CasnRJgpER41",
"colab_type": "code",
"outputId": "9709955e-86a0-4163-e99f-8a9db8aee520",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"bobs_model.fc3.bias"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([-0.0842], requires_grad=True)"
]
},
"metadata": {
"tags": []
},
"execution_count": 66
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "1LyM4a3gEeit",
"colab_type": "code",
"outputId": "147efd1b-328a-410c-8899-4e6f1f2f3acd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"alices_model.fc3.bias"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([-0.0982], requires_grad=True)"
]
},
"metadata": {
"tags": []
},
"execution_count": 67
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DMxUZ4gOjivJ",
"colab_type": "code",
"colab": {}
},
"source": [
"models = [bobs_model, alices_model]\n",
"optimizers = [bobs_optimizer, alices_optimizer]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5JQLJc8njkxV",
"colab_type": "code",
"outputId": "0bd16c99-5503-46a9-922d-ba1353b0e0bd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
}
},
"source": [
"model = Net()\n",
"model"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Net(\n",
" (fc1): Linear(in_features=13, out_features=32, bias=True)\n",
" (fc2): Linear(in_features=32, out_features=24, bias=True)\n",
" (fc4): Linear(in_features=24, out_features=16, bias=True)\n",
" (fc3): Linear(in_features=16, out_features=1, bias=True)\n",
")"
]
},
"metadata": {
"tags": []
},
"execution_count": 69
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eNO2Z8vEl6F1",
"colab_type": "text"
},
"source": [
"**TRAINING**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zTzt6bLyjmlr",
"colab_type": "code",
"colab": {}
},
"source": [
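"def update(data, target, model, optimizer):\n",
"    # Move the model to the worker that holds this batch and take one SGD step there\n",
"    model.send(data.location)\n",
"    optimizer.zero_grad()\n",
"    prediction = model(data)\n",
"    loss = F.mse_loss(prediction.view(-1), target)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    return model\n",
"\n",
"def train():\n",
"    # Workers can hold different numbers of batches, so stop at the shortest list\n",
"    num_rounds = min(len(batches) for batches in remote_dataset)\n",
"    for data_index in range(num_rounds):\n",
"        for remote_index in range(len(compute_nodes)):\n",
"            data, target = remote_dataset[remote_index][data_index]\n",
"            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])\n",
"    # Bring both models back and average their parameters\n",
"    for model in models:\n",
"        model.get()\n",
"    # The average is returned for evaluation only; bob's and alice's models are not reset to it\n",
"    return utils.federated_avg({\n",
"        \"bob\": models[0],\n",
"        \"alice\": models[1]\n",
"    })"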
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "OUEH67a6jpkn",
"colab_type": "code",
"colab": {}
},
"source": [
"def test(federated_model):\n",
"    federated_model.eval()\n",
"    test_loss = 0\n",
"    with torch.no_grad():\n",
"        for data, target in test_loader:\n",
"            output = federated_model(data)\n",
"            # Sum squared errors so the division below gives the per-sample MSE\n",
"            test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item()\n",
"    test_loss /= len(test_loader.dataset)\n",
"    print('Test set: Average loss: {:.4f}'.format(test_loss))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "W6MiwI_djryn",
"colab_type": "code",
"outputId": "4e03875d-7778-4ebf-bd98-5720bc868ac0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"for epoch in range(args.epochs):\n",
"    start_time = time.time()\n",
"    print(f\"Epoch Number {epoch + 1}\")\n",
"    federated_model = train()\n",
"    model = federated_model\n",
"    test(federated_model)\n",
"    # VirtualWorkers run in this process, so this is the wall-clock time of one\n",
"    # epoch (training, averaging and testing), not real network latency\n",
"    total_time = time.time() - start_time\n",
"    print('Communication time over the network', round(total_time, 2), 's\\n')\n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch Number 1\n",
"Test set: Average loss: 615.8278\n",
"Communication time over the network 0.09 s\n",
"\n",
"Epoch Number 2\n",
"Test set: Average loss: 613.6289\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 3\n",
"Test set: Average loss: 610.8525\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 4\n",
"Test set: Average loss: 607.9232\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 5\n",
"Test set: Average loss: 604.9781\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 6\n",
"Test set: Average loss: 602.0598\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 7\n",
"Test set: Average loss: 599.1488\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 8\n",
"Test set: Average loss: 596.2221\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 9\n",
"Test set: Average loss: 593.2520\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 10\n",
"Test set: Average loss: 590.2224\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 11\n",
"Test set: Average loss: 587.1091\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 12\n",
"Test set: Average loss: 583.8926\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 13\n",
"Test set: Average loss: 580.5557\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 14\n",
"Test set: Average loss: 577.0765\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 15\n",
"Test set: Average loss: 573.4352\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 16\n",
"Test set: Average loss: 569.6040\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 17\n",
"Test set: Average loss: 565.5632\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 18\n",
"Test set: Average loss: 561.2832\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 19\n",
"Test set: Average loss: 556.7154\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 20\n",
"Test set: Average loss: 551.8287\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 21\n",
"Test set: Average loss: 546.5705\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 22\n",
"Test set: Average loss: 540.8797\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 23\n",
"Test set: Average loss: 534.6752\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 24\n",
"Test set: Average loss: 527.8565\n",
"Communication time over the network 0.09 s\n",
"\n",
"Epoch Number 25\n",
"Test set: Average loss: 520.3152\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 26\n",
"Test set: Average loss: 511.9150\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 27\n",
"Test set: Average loss: 502.4826\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 28\n",
"Test set: Average loss: 491.8018\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 29\n",
"Test set: Average loss: 479.5735\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 30\n",
"Test set: Average loss: 465.3832\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 31\n",
"Test set: Average loss: 448.7991\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 32\n",
"Test set: Average loss: 429.2040\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 33\n",
"Test set: Average loss: 405.8127\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 34\n",
"Test set: Average loss: 377.8263\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 35\n",
"Test set: Average loss: 344.3854\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 36\n",
"Test set: Average loss: 304.8530\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 37\n",
"Test set: Average loss: 259.4683\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 38\n",
"Test set: Average loss: 210.3587\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 39\n",
"Test set: Average loss: 162.4481\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 40\n",
"Test set: Average loss: 122.8936\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 41\n",
"Test set: Average loss: 96.7920\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 42\n",
"Test set: Average loss: 82.9358\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 43\n",
"Test set: Average loss: 76.2781\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 44\n",
"Test set: Average loss: 72.7867\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 45\n",
"Test set: Average loss: 70.5361\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 46\n",
"Test set: Average loss: 68.9330\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 47\n",
"Test set: Average loss: 67.5992\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 48\n",
"Test set: Average loss: 66.4083\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 49\n",
"Test set: Average loss: 65.3180\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 50\n",
"Test set: Average loss: 64.2877\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 51\n",
"Test set: Average loss: 63.3141\n",
"Communication time over the network 0.1 s\n",
"\n",
"Epoch Number 52\n",
"Test set: Average loss: 62.3876\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 53\n",
"Test set: Average loss: 61.4981\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 54\n",
"Test set: Average loss: 60.6435\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 55\n",
"Test set: Average loss: 59.8251\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 56\n",
"Test set: Average loss: 59.0407\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 57\n",
"Test set: Average loss: 58.2781\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 58\n",
"Test set: Average loss: 57.5338\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 59\n",
"Test set: Average loss: 56.8279\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 60\n",
"Test set: Average loss: 56.1354\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 61\n",
"Test set: Average loss: 55.4626\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 62\n",
"Test set: Average loss: 54.8033\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 63\n",
"Test set: Average loss: 54.1732\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 64\n",
"Test set: Average loss: 53.5623\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 65\n",
"Test set: Average loss: 52.9653\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 66\n",
"Test set: Average loss: 52.3816\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 67\n",
"Test set: Average loss: 51.8223\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 68\n",
"Test set: Average loss: 51.2816\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 69\n",
"Test set: Average loss: 50.7554\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 70\n",
"Test set: Average loss: 50.2426\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 71\n",
"Test set: Average loss: 49.7441\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 72\n",
"Test set: Average loss: 49.2661\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 73\n",
"Test set: Average loss: 48.8044\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 74\n",
"Test set: Average loss: 48.3548\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 75\n",
"Test set: Average loss: 47.9030\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 76\n",
"Test set: Average loss: 47.4720\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 77\n",
"Test set: Average loss: 47.0545\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 78\n",
"Test set: Average loss: 46.6433\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 79\n",
"Test set: Average loss: 46.2404\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 80\n",
"Test set: Average loss: 45.8404\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 81\n",
"Test set: Average loss: 45.4649\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 82\n",
"Test set: Average loss: 45.0982\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 83\n",
"Test set: Average loss: 44.7320\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 84\n",
"Test set: Average loss: 44.4494\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 85\n",
"Test set: Average loss: 44.0800\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 86\n",
"Test set: Average loss: 43.7343\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 87\n",
"Test set: Average loss: 43.4088\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 88\n",
"Test set: Average loss: 43.1473\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 89\n",
"Test set: Average loss: 42.8142\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 90\n",
"Test set: Average loss: 42.4894\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 91\n",
"Test set: Average loss: 42.2626\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 92\n",
"Test set: Average loss: 41.9549\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 93\n",
"Test set: Average loss: 41.6637\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 94\n",
"Test set: Average loss: 41.4593\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 95\n",
"Test set: Average loss: 41.1788\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 96\n",
"Test set: Average loss: 40.9285\n",
"Communication time over the network 0.08 s\n",
"\n",
"Epoch Number 97\n",
"Test set: Average loss: 40.7471\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 98\n",
"Test set: Average loss: 40.4832\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 99\n",
"Test set: Average loss: 40.2277\n",
"Communication time over the network 0.07 s\n",
"\n",
"Epoch Number 100\n",
"Test set: Average loss: 40.0887\n",
"Communication time over the network 0.07 s\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Tb1auHcQjt8E",
"colab_type": "code",
"outputId": "79218f41-bdba-422f-c029-7a4dce4665ef",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
}
},
"source": [
"models"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[Net(\n",
" (fc1): Linear(in_features=13, out_features=32, bias=True)\n",
" (fc2): Linear(in_features=32, out_features=24, bias=True)\n",
" (fc4): Linear(in_features=24, out_features=16, bias=True)\n",
" (fc3): Linear(in_features=16, out_features=1, bias=True)\n",
" ), Net(\n",
" (fc1): Linear(in_features=13, out_features=32, bias=True)\n",
" (fc2): Linear(in_features=32, out_features=24, bias=True)\n",
" (fc4): Linear(in_features=24, out_features=16, bias=True)\n",
" (fc3): Linear(in_features=16, out_features=1, bias=True)\n",
" )]"
]
},
"metadata": {
"tags": []
},
"execution_count": 73
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QzPTZ8UcEjnI",
"colab_type": "text"
},
"source": [
"**WEIGHTS AFTER TRAINING**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "G5wyrGImDbDW",
"colab_type": "code",
"outputId": "a0df5917-c7a3-418a-9c0b-60e0d23fe220",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"bobs_model.fc3.bias"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([1.3315], requires_grad=True)"
]
},
"metadata": {
"tags": []
},
"execution_count": 74
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "YyaEGIkNCvDA",
"colab_type": "code",
"outputId": "c25cb10d-d20f-4946-bada-0fd7fe07dfa5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"alices_model.fc3.bias"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([1.3244], requires_grad=True)"
]
},
"metadata": {
"tags": []
},
"execution_count": 75
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "iWF9qJbNEHQ9",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
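The DATA PREPROCESSING cells standardize each feature using the training-set mean and standard deviation. They apply the same statistics to the test set, so no test information leaks into preprocessing. Column 3 is left unscaled: in the Boston housing data that column is CHAS, a 0/1 flag for tracts bordering the Charles River.

Below is a minimal sketch of the same idea as a reusable function, checked on randomly generated data. The function name `fit_standardizer` and the data are illustrative and do not come from the notebook.

```python
import torch


def fit_standardizer(x: torch.Tensor, skip_columns=()):
    """Compute per-column mean and std on the training data. Columns listed in
    skip_columns get mean 0 and std 1 so they pass through unchanged."""
    mean = x.mean(0, keepdim=True)
    std = x.std(0, keepdim=True)
    for col in skip_columns:
        mean[:, col] = 0.0
        std[:, col] = 1.0
    return mean, std


torch.manual_seed(0)
# Hypothetical data: 3 continuous columns plus a 0/1 indicator in column 3
x_train = torch.cat([torch.randn(100, 3) * 5 + 10,
                     torch.randint(0, 2, (100, 1)).float()], dim=1)
x_test = torch.cat([torch.randn(20, 3) * 5 + 10,
                    torch.randint(0, 2, (20, 1)).float()], dim=1)

mean, std = fit_standardizer(x_train, skip_columns=(3,))
x_train_std = (x_train - mean) / std
x_test_std = (x_test - mean) / std  # reuse the training statistics

print(x_train_std[:, :3].mean(0))                      # close to 0 for the scaled columns
print(torch.equal(x_train_std[:, 3], x_train[:, 3]))   # True: indicator untouched
```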
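The NEURAL NETWORK ARCHITECTURE cell defines a small fully connected regressor: 13 inputs, hidden layers of 32, 24 and 16 ReLU units, and a single output. Each `nn.Linear(in, out)` layer holds `in * out` weights plus `out` biases, so the model's size can be worked out by hand.

The plain-Python sketch below does that arithmetic. On the notebook's model, `sum(p.numel() for p in Net().parameters())` should report the same total.

```python
# Layer widths of the notebook's Net: 13 inputs -> 32 -> 24 -> 16 -> 1 output
LAYER_SIZES = [13, 32, 24, 16, 1]


def linear_param_count(in_features: int, out_features: int) -> int:
    """Parameters in nn.Linear(in_features, out_features) with bias=True."""
    return in_features * out_features + out_features


def mlp_param_count(sizes):
    """Total parameters of a stack of Linear layers between consecutive widths."""
    return sum(linear_param_count(a, b) for a, b in zip(sizes, sizes[1:]))


per_layer = [linear_param_count(a, b) for a, b in zip(LAYER_SIZES, LAYER_SIZES[1:])]
print(per_layer)                     # [448, 792, 400, 17]
print(mlp_param_count(LAYER_SIZES))  # 1657
```

At about 1,700 parameters the model is tiny, which fits the roughly 400 training rows of the Boston data.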
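The cell that sends the training batches to the workers deals them round-robin: batch 0 to bob, batch 1 to alice, batch 2 to bob, and so on. With an odd number of batches, bob ends up with one extra. A loop that takes one batch from each worker per round must therefore stop at the shorter list, or it will index past the end of alice's.

The sketch below replays that bookkeeping in plain Python. The 404-row training set is an assumption: it matches the Keras `boston_housing` split, which the pickle's `((x, y), (x_test, y_test))` layout resembles, but the notebook does not state the size.

```python
def deal_round_robin(num_batches: int, num_workers: int):
    """Return, for each worker, the indices of the batches it receives when
    batch i goes to worker i % num_workers."""
    assignments = [[] for _ in range(num_workers)]
    for batch_idx in range(num_batches):
        assignments[batch_idx % num_workers].append(batch_idx)
    return assignments


def num_batches(num_samples: int, batch_size: int) -> int:
    """Batches a DataLoader yields with drop_last=False (ceiling division)."""
    return -(-num_samples // batch_size)


# Assumption: 404 training rows, as in the Keras boston_housing training split
batches = num_batches(404, 8)
bob, alice = deal_round_robin(batches, 2)
print(batches, len(bob), len(alice))  # 51 26 25

# Rounds in which every worker still has a batch to train on
rounds = min(len(b) for b in (bob, alice))
print(rounds)  # 25
```

With these numbers, the final batch (index 50, only 4 samples) belongs to bob and is never used for training.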
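In the TRAINING cells, `utils.federated_avg` averages bob's and alice's parameters into `federated_model`, but that average is only used for testing. It is never copied back to the workers, so each worker keeps refining its own copy. That is why the two `fc3.bias` values still differ after 100 epochs (1.3315 vs 1.3244). Federated averaging (FedAvg) normally starts every round from the averaged weights.

Below is a minimal sketch of one average-and-broadcast step in plain PyTorch, without PySyft. Two `nn.Linear` layers stand in for the workers' models. It uses an unweighted mean, which matches FedAvg when the clients hold equally sized datasets. This is an illustration, not the notebook's code.

```python
import copy

import torch
import torch.nn as nn


def federated_average(models):
    """Return a new model whose parameters are the unweighted element-wise mean
    of the given models. Assumes identical architectures and only floating-point
    state (true for Linear-only networks such as Net)."""
    averaged = copy.deepcopy(models[0])
    avg_state = averaged.state_dict()
    for key in list(avg_state.keys()):
        avg_state[key] = torch.stack([m.state_dict()[key] for m in models]).mean(dim=0)
    averaged.load_state_dict(avg_state)
    return averaged


def broadcast(global_model, client_models):
    """Copy the global parameters into every client model in place, so
    optimizers built on the clients keep referencing the same tensors."""
    for client in client_models:
        client.load_state_dict(global_model.state_dict())


torch.manual_seed(0)
# Two small stand-in models playing the roles of bob's and alice's copies
clients = [nn.Linear(3, 1), nn.Linear(3, 1)]
expected_weight = (clients[0].weight.detach() + clients[1].weight.detach()) / 2

global_model = federated_average(clients)
broadcast(global_model, clients)

print(torch.allclose(global_model.weight, expected_weight))  # True
print(torch.equal(clients[0].weight, clients[1].weight))     # True
```

`load_state_dict` copies values into each client's existing parameter tensors, so optimizers created beforehand (like `bobs_optimizer` and `alices_optimizer`) stay attached to the right parameters. Applying the same broadcast to PySyft-hooked models after `model.get()` is an untested assumption.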
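The `test` function adds up squared errors with `reduction='sum'` and divides by `len(test_loader.dataset)`. This gives the mean squared error over all test samples even when the last batch is smaller. Averaging per-batch means instead would overweight a short final batch. If the test split has 102 rows, as in the Keras version of the dataset, the last of its 13 batches holds only 6 samples.

The made-up numbers below show the difference. The last line converts the notebook's final reported loss into a root-mean-squared error, which is in the target's units (median home value, in thousands of dollars for the Boston data).

```python
import math


def dataset_mse(batches):
    """Sum squared errors over all samples, then divide by the sample count
    (what reduction='sum' followed by division by len(dataset) computes)."""
    total = sum((p - t) ** 2 for batch in batches for p, t in batch)
    count = sum(len(batch) for batch in batches)
    return total / count


def mean_of_batch_means(batches):
    """Average of per-batch MSEs; overweights a short final batch."""
    means = [sum((p - t) ** 2 for p, t in batch) / len(batch) for batch in batches]
    return sum(means) / len(means)


# Made-up (prediction, target) pairs: one batch of 4 and a short final batch of 2
batches = [
    [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)],  # squared errors 1, 1, 1, 1
    [(3.0, 0.0), (3.0, 0.0)],                          # squared errors 9, 9
]
print(dataset_mse(batches))          # (4 + 18) / 6, about 3.667
print(mean_of_batch_means(batches))  # (1 + 9) / 2 = 5.0

# RMSE of the notebook's last reported test loss, in the target's units
print(round(math.sqrt(40.0887), 2))  # 6.33
```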