SOLVED:Flower Hands-on Tutorial-PyTorch.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/leriomaggio/30fc8892e1c6b7f5b3b23b46e734e5c9/solved-flower-hands-on-tutorial-pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r41cmZoW1CF8"
},
"source": [
"# Flower Hands on Coding Challange\n",
"\n",
"Welcome to Flower Hands-on Tutorial!\n",
"\n",
"In this notebook, you'll build a federated learning system using **Flower** and **PyTorch** with **MNIST** dataset.\n",
"\n",
"We will provide you with the following code:\n",
"* ML model definition,\n",
"* train, test functions.\n",
"\n",
"You will need to implement the following elements:\n",
"* data division (clients partitioning and train/test division),\n",
"* Flower Client class,\n",
"* Flower Strategy (initialization or groud-up creation),\n",
"* distributed evaluation,\n",
"* centralized evaluation.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VHJaF8yX1CF-"
},
"source": [
"### Install dependencies & Import libraries\n",
"\n",
"Next, we install the necessary packages\n",
"* Flower (`flwr`),\n",
"* PyTorch (`torch`, `torchvision`),"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5jU4leX31CF-",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7d0a4b73-5ea3-4670-c792-4240d70dd6a4"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m157.2/157.2 kB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.6/58.6 MB\u001b[0m \u001b[31m17.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m149.6/149.6 kB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.7/8.7 MB\u001b[0m \u001b[31m77.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.8/4.8 MB\u001b[0m \u001b[31m45.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m64.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m201.4/201.4 kB\u001b[0m \u001b[31m24.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.0/3.0 MB\u001b[0m \u001b[31m64.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m97.9/97.9 kB\u001b[0m \u001b[31m12.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m128.2/128.2 kB\u001b[0m \u001b[31m17.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m114.5/114.5 kB\u001b[0m \u001b[31m14.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m33.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.4/58.4 kB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m468.5/468.5 kB\u001b[0m \u001b[31m53.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Building wheel for gpustat (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
]
}
],
"source": [
"!pip install -q flwr[simulation] torch torchvision matplotlib;"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dEG3b87r1CF_"
},
"source": [
"Now that we have all dependencies installed, we can import everything we need for this tutorial:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MlbJRGwP1CF_",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "270d7be6-a045-481d-8231-438334766816"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training on cpu using PyTorch 2.0.1+cu118 and Flower 1.4.0\n"
]
}
],
"source": [
"from collections import OrderedDict\n",
"from typing import List, Tuple, Dict, Optional\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from torch.utils.data import DataLoader, random_split\n",
"from torchvision.datasets import MNIST\n",
"from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset, random_split\n",
"\n",
"import flwr as fl\n",
"from flwr.common import Metrics\n",
"from flwr.common.typing import NDArrays, Scalar\n",
"\n",
"DEVICE = torch.device(\"cpu\") # Try \"cuda\" to train on GPU\n",
"print(\n",
" f\"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}\"\n",
")\n"
]
},
{
"cell_type": "markdown",
"source": [
"Let's define some useful constants that we will need along the tutorial."
],
"metadata": {
"id": "K5pZIQByAgFb"
}
},
{
"cell_type": "code",
"source": [
"SEED = 42\n",
"NUM_CLIENTS = 10\n",
"BATCH_SIZE = 32\n",
"VALID_FRACTION = 0.2 # fraction of the dataset used for each local client\n",
"NUM_ROUNDS = 2"
],
"metadata": {
"id": "G5YZJIkvzfc-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Data\n",
"\n",
"The aim of this section is to create a divided MNIST dataset to simulate the federated learning evironment.\n",
"\n",
"We provide you with the `download_data` function and want you to implement the following:\n",
"\n",
"* `partition_data`,\n",
"* `train_val_divide_local_datasets`.\n",
"\n",
"You are given the function prototype - function name, the necessary information about a function, arguments, their types, and the return type.\n",
"\n",
"If implemented correctly they should be able to run `load_datasets` function that creates the divided datasets.\n",
"\n",
"Firstly let's just have a quick look at the data (already prepared, just run the cells below)."
],
"metadata": {
"id": "qmi_O9qspq4J"
}
},
{
"cell_type": "code",
"source": [
"def download_data() -> Tuple[Dataset, Dataset]:\n",
" transform = transforms.Compose(\n",
" [transforms.ToTensor(),]# transforms.Normalize((0.1307,), (0.3081,))\n",
" )\n",
" trainset = MNIST(\"./dataset\", train=True, download=True, transform=transform)\n",
" testset = MNIST(\"./dataset\", train=False, download=True, transform=transform)\n",
" return trainset, testset\n",
"# Keep the testset for centralized (optional) centralized evaluation.\n",
"trainset, testset = download_data()"
],
"metadata": {
"id": "RaHMfzZipqdQ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "54b14f43-ab49-45c3-9363-4c2365f8cba2"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 9912422/9912422 [00:00<00:00, 101700593.79it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting ./dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 28881/28881 [00:00<00:00, 53176336.18it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 1648877/1648877 [00:00<00:00, 24754514.11it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 4542/4542 [00:00<00:00, 16681723.96it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Extracting ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Quick EDA"
],
"metadata": {
"id": "ube7hvyewiha"
}
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"dataiter = iter(torch.utils.data.DataLoader(trainset, batch_size=64))\n",
"images, labels = next(dataiter)\n",
"images = images.permute(0, 2, 3, 1).numpy()\n",
"\n",
"# Create a figure and a grid of subplots\n",
"fig, axs = plt.subplots(4, 8, figsize=(10, 4))\n",
"\n",
"# Loop over the images and plot them\n",
"for i, ax in enumerate(axs.flat):\n",
" ax.imshow(images[i], cmap='gray')\n",
" ax.set_title(labels[i].numpy())\n",
" ax.axis(\"off\")\n",
"\n",
"# Show the plot\n",
"fig.tight_layout()\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 407
},
"id": "Swp-IN2Zqtto",
"outputId": "abe49e71-63f1-453f-e791-e119e77ede88"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x400 with 32 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"labels = pd.Series([label for _, label in trainset])\n",
"counts = labels.value_counts().sort_index()\n",
"counts.plot.bar(figsize=(7, 4), color=[\"#F2B705\"])\n",
"plt.ylabel(\"Count\", fontsize=16)\n",
"plt.xlabel(\"Label\", fontsize=16)\n",
"plt.xticks(fontsize=14, rotation=0)\n",
"plt.yticks(fontsize=12)\n",
"ax = plt.gca()\n",
"plt.tight_layout()\n",
"ax.spines['top'].set_visible(False)\n",
"ax.spines['right'].set_visible(False)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 406
},
"id": "9nRkBQ_UvZHw",
"outputId": "0efdd1d6-b84a-4ea6-de08-276be6ad2ca0"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 700x400 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"We see that the data is quite evenly distributed. Let's sample the data random - iid sampling."
],
"metadata": {
"id": "3_d59TyjwqDv"
}
},
{
"cell_type": "code",
"source": [
"def partition_data(dataset: Dataset, n_partitions: int) -> List[Dataset]:\n",
" \"\"\"\n",
" Split the dataset into iid partitions to simulate federated learning.\n",
"\n",
" Returns\n",
" -------\n",
" List[Dataset]\n",
" A list of dataset (one dataset for every client)\n",
" \"\"\"\n",
" partition_size = int(len(dataset) / n_partitions)\n",
" lengths = [partition_size] * n_partitions\n",
" datasets = random_split(dataset, lengths, torch.Generator().manual_seed(SEED))\n",
" return datasets"
],
"metadata": {
"id": "VaPcBizQwhMv"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def train_val_divide_local_datasets(local_datasets: List[Dataset], valid_fraction: float) -> Tuple[List[Dataset], List[Dataset]]:\n",
" \"\"\"Split each local dataset into train and validation.\"\"\"\n",
" trainloaders = []\n",
" validloaders = []\n",
"\n",
" for dataset in local_datasets:\n",
" validation_lenght = int(len(dataset) * valid_fraction)\n",
" train_length = len(dataset) - validation_lenght\n",
" lengths = [train_length, validation_lenght]\n",
" train_dataset, validation_dataset = random_split(\n",
" dataset, lengths, torch.Generator().manual_seed(SEED)\n",
" )\n",
" trainloaders.append(DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True))\n",
" validloaders.append(DataLoader(validation_dataset, batch_size=BATCH_SIZE))\n",
"\n",
" return trainloaders, validloaders"
],
"metadata": {
"id": "p--8Tq020Kg7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def load_datasets(n_partitions: int, valid_fraction: float, batch_size:int) -> Tuple[List[DataLoader], List[DataLoader], DataLoader]:\n",
" \"\"\"Handles the MNIST data creation for federated learning.\n",
"\n",
" It starts from downloading, thought partitioning, train test division and centralized dataset creation.\n",
"\n",
" Parameters\n",
" ----------\n",
" n_partitions: int\n",
" The number of partitions the MNIST train set is divided into.\n",
" valid_split: float\n",
" The fraction of the validaiton data in each local dataset.\n",
" batch_size: int\n",
" The size of batch.\n",
"\n",
" Returns\n",
" -------\n",
" Tuple[List[DataLoader], List[DataLoader], DataLoader]\n",
" Local train datasets, local validation datasets, and a centralized dataset\n",
" \"\"\"\n",
" # DO NOT MODIFY THIS CODE\n",
" trainset, testset = download_data()\n",
" local_datasets = partition_data(trainset, n_partitions)\n",
" trainloaders, validloaders = train_val_divide_local_datasets(local_datasets, valid_fraction)\n",
" centralized_loader = DataLoader(testset, batch_size=batch_size)\n",
" return trainloaders, validloaders, centralized_loader"
],
"metadata": {
"id": "mI0JmcfA0ETs"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# DO NOT MODIFY THIS CODE\n",
"trainloaders, validloaders, centralized_loader = load_datasets(\n",
" n_partitions=NUM_CLIENTS,\n",
" valid_fraction=VALID_FRACTION,\n",
" batch_size=BATCH_SIZE)"
],
"metadata": {
"id": "9Ul4N6Fy1krn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"assert len(trainloaders) == NUM_CLIENTS, f\"The number of train partitions should be equal to the number of clients = {NUM_CLIENTS} but is {len(trainloaders)} instead\"\n",
"assert len(validloaders) == NUM_CLIENTS, f\"The number of validation partitions should be equal to the number of clients = {NUM_CLIENTS} but is {len(validloaders)} instead\""
],
"metadata": {
"id": "POpUjm78bmEw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hg0T11jr1CGB"
},
"source": [
"## Test Solution using Centralized Training\n",
"\n",
"In this section you are not required to implement anything. You can test your solution by doing centralized training on one of the partitions of the data by simply running the code below.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KmBV1kh31CGB"
},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
" \"\"\"Basic CNN implementation\"\"\"\n",
" def __init__(self) -> None:\n",
" super(Net, self).__init__()\n",
" self.conv1 = nn.Conv2d(1, 32, 5, padding=\"same\")\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(32, 64, 5, padding=\"same\")\n",
" self.fc1 = nn.Linear(64 * 7 * 7, 2048)\n",
" self.fc2 = nn.Linear(2048, 10)\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = x.view(-1, 7 * 7 * 64)\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wFYFDFc81CGB"
},
"source": [
"Let's have a look at the usual training and test functions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NICmNQgb1CGC"
},
"outputs": [],
"source": [
"def train(net: nn.Module, trainloader: DataLoader, epochs: int, verbose=False):\n",
" \"\"\"Train the neural network for a classification task.\"\"\"\n",
" criterion = torch.nn.CrossEntropyLoss()\n",
" optimizer = torch.optim.Adam(net.parameters())\n",
" net.train()\n",
" for epoch in range(epochs):\n",
" correct, total, epoch_loss = 0, 0, 0.0\n",
" for images, labels in trainloader:\n",
" images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
" optimizer.zero_grad()\n",
" outputs = net(images)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
" # Metrics\n",
" epoch_loss += loss\n",
" total += labels.size(0)\n",
" correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()\n",
" epoch_loss /= len(trainloader.dataset)\n",
" epoch_acc = correct / total\n",
" if verbose:\n",
" print(f\"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}\")\n",
"\n",
"\n",
"def test(net: nn.Module, testloader: DataLoader):\n",
" \"\"\"Test the neural network used for classification task.\"\"\"\n",
" criterion = torch.nn.CrossEntropyLoss()\n",
" correct, total, loss = 0, 0, 0.0\n",
" net.eval()\n",
" with torch.no_grad():\n",
" for images, labels in testloader:\n",
" images, labels = images.to(DEVICE), labels.to(DEVICE)\n",
" outputs = net(images)\n",
" loss += criterion(outputs, labels).item()\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
" loss /= len(testloader.dataset)\n",
" accuracy = correct / total\n",
" return loss, accuracy"
]
},
{
"cell_type": "markdown",
"source": [
"Run the centralized training."
],
"metadata": {
"id": "I6rEAgbo5KC_"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QhVEhZbc1CGC",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a7255c53-ec22-4259-e66a-d43f012d7304"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1: validation loss 0.04788221110900243, accuracy 0.37083333333333335\n",
"Epoch 2: validation loss 0.047523756523927055, accuracy 0.37083333333333335\n",
"Epoch 3: validation loss 0.046505961616834006, accuracy 0.37666666666666665\n",
"Epoch 4: validation loss 0.04789999047915141, accuracy 0.36333333333333334\n",
"Epoch 5: validation loss 0.04795777161916097, accuracy 0.37333333333333335\n",
"Final test set performance:\n",
"\tloss 0.04549221425056457\n",
"\taccuracy 0.3926\n"
]
}
],
"source": [
"trainloader = trainloaders[0]\n",
"valloader = validloaders[0]\n",
"net = Net().to(DEVICE)\n",
"\n",
"for epoch in range(5):\n",
" train(net, trainloader, 1)\n",
" loss, accuracy = test(net, valloader)\n",
" print(f\"Epoch {epoch+1}: validation loss {loss}, accuracy {accuracy}\")\n",
"\n",
"loss, accuracy = test(net, centralized_loader)\n",
"print(f\"Final test set performance:\\n\\tloss {loss}\\n\\taccuracy {accuracy}\")"
]
},
{
"cell_type": "markdown",
"source": [
"You should see about 59% accuracy after the 5th epoch and 60% on the centralized dataset."
],
"metadata": {
"id": "mvLPIhI3fVfj"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "kt4G4wHZ1CGC"
},
"source": [
"## Federated Learning\n",
"\n",
"Now, we'll move to implementing federated learning system.\n",
"\n",
"You will need to implement `FlowerClinet`, create Flower Strategy e.g. `FedAvg` and start simulation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CEdPhgl1CGC"
},
"source": [
"### Updating model parameters\n",
"\n",
"In federated learning, the server sends the global model parameters to the client, and the client updates the local model with the parameters received from the server. It then trains the model on the local data (which changes the model parameters locally) and sends the updated/changed model parameters back to the server (or, alternatively, it sends just the gradients back to the server, not the full model parameters).\n",
"\n",
"We need two helper functions to update the local model with parameters received from the server and to get the updated model parameters from the local model: `set_parameters` and `get_parameters`. The following two functions do just that for the PyTorch model above.\n",
"\n",
"The details of how this works are not really important here (feel free to consult the PyTorch documentation if you want to learn more). In essence, we use `state_dict` to access PyTorch model parameter tensors. The parameter tensors are then converted to/from a list of NumPy ndarray's (which Flower knows how to serialize/deserialize):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8aRP3XFq1CGC"
},
"outputs": [],
"source": [
"def get_parameters(net) -> List[np.ndarray]:\n",
" return [val.cpu().numpy() for _, val in net.state_dict().items()]\n",
"\n",
"\n",
"def set_parameters(net, parameters: List[np.ndarray]):\n",
" params_dict = zip(net.state_dict().keys(), parameters)\n",
" state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})\n",
" net.load_state_dict(state_dict, strict=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m5zn88po1CGC"
},
"source": [
"### Implement a Flower client\n",
"\n",
"In Flower, we create clients by implementing subclasses of `flwr.client.Client` or `flwr.client.NumPyClient`. We use `NumPyClient` in this tutorial because it is easier to implement and requires us to write less boilerplate.\n",
"\n",
"To implement the Flower client, we create a subclass of `flwr.client.NumPyClient` and implement the three methods `get_parameters`, `fit`, and `evaluate`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cjLscSeF1CGC"
},
"outputs": [],
"source": [
"from flwr.common.typing import NDArrays\n",
"class FlowerClient(fl.client.NumPyClient):\n",
" \"\"\"\n",
" Class representing a single client in FL system, required to use Flower.\n",
" \"\"\"\n",
" def __init__(self, net: nn.Module, trainloader: DataLoader, valloader: DataLoader):\n",
" self.net = net\n",
" self.trainloader = trainloader\n",
" self.valloader = valloader\n",
"\n",
" def get_parameters(self, config):\n",
" \"\"\"Return the current local model parameters\"\"\"\n",
" return get_parameters(self.net)\n",
"\n",
" def fit(self, parameters: NDArrays, config: Dict[str, Scalar]) -> NDArrays:\n",
" \"\"\"Train the model on the local (train) data.\n",
"\n",
" Parameters\n",
" ----------\n",
" parameters: NDarrays\n",
" Model parameters (weights) received from the server\n",
"\n",
" config: Dict[str, Scalar]\n",
" Server based configuration (needed only if you require dynamically changing values).\n",
"\n",
" Returns\n",
" -------\n",
" NDArrays\n",
" Updated model parameters\n",
"\n",
" \"\"\"\n",
" set_parameters(self.net, parameters)\n",
" train(self.net, self.trainloader, epochs=1)\n",
" return get_parameters(self.net), len(self.trainloader), {}\n",
"\n",
" def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar])-> Tuple[float, int, Dict[str, Scalar]]:\n",
" \"\"\"Evaluate model using the validation data.\n",
"\n",
" Parameters\n",
" ----------\n",
" parameters: NDarrays\n",
" Model parameters (weights) received from the server\n",
"\n",
" config: Dict[str, Scalar]\n",
" Server based configuration (needed only if you require dynamically changing values).\n",
"\n",
" Returns\n",
" -------\n",
" loss : float\n",
" The evaluation loss of the model on the local dataset.\n",
" num_examples : int\n",
" The number of examples used for evaluation.\n",
" metrics : Dict[str, Scalar]\n",
" A dictionary mapping arbitrary string keys to values of\n",
" type bool, bytes, float, int, or str. It can be used to\n",
" communicate arbitrary values back to the server.\n",
" \"\"\"\n",
" set_parameters(self.net, parameters)\n",
" loss, accuracy = test(self.net, self.valloader)\n",
" return float(loss), len(self.valloader), {\"accuracy\": float(accuracy)}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_5czKJ8U1CGC"
},
"source": [
"Our class `FlowerClient` defines how local training/evaluation will be performed and allows Flower to call the local training/evaluation through `fit` and `evaluate`. Each instance of `FlowerClient` represents a *single client* in our federated learning system. Federated learning systems have multiple clients (otherwise, there's not much to federate), so each client will be represented by its own instance of `FlowerClient`. If we have, for example, three clients in our workload, then we'd have three instances of `FlowerClient`. Flower calls `FlowerClient.fit` on the respective instance when the server selects a particular client for training (and `FlowerClient.evaluate` for evaluation).\n",
"\n",
"### Use the Virtual Client Engine\n",
"\n",
"We will simulate a federated learning system with 10 clients on a single machine = 10 instances of `FlowerClient` in memory. Doing this on a single machine.\n",
"\n",
"Flower creates `FlowerClient` instances only when they are actually necessary for training or evaluation by callling `client_fn` that returns a `FlowerClient` instance on demand. After using them for `fit` or `evaluate` they are discarded, so they should not keep any local state.\n",
"\n",
"`client_fn` takes a single argument `cid` - a client ID. The `cid` can be used, for example, to load different local data partitions for different clients."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qKtlrIZU1CGC"
},
"outputs": [],
"source": [
"# We can use the fact that we are using Jupyter Notebook environment and use the data without providing it as an argument.\n",
"def create_client_fn(cid: str) -> FlowerClient:\n",
" \"\"\"Create a Flower client representing a single organization.\"\"\"\n",
" net = Net().to(DEVICE)\n",
" trainloader = trainloaders[int(cid)]\n",
" valloader = validloaders[int(cid)]\n",
" return FlowerClient(net, trainloader, valloader)\n"
]
},
{
"cell_type": "markdown",
"source": [
"### Metrics Aggregation\n",
"Flower can automatically aggregate losses returned by individual clients, but it cannot do the same for metrics in the generic metrics dictionary (the one with the `accuracy` key). Metrics dictionaries can contain very different kinds of metrics and even key/value pairs that are not metrics at all, so the framework does not (and can not) know how to handle these automatically.\n",
"\n",
"The `weighted_average` function has to be passed to `evaluate_metrics_aggregation_fn` in your strategy."
],
"metadata": {
"id": "OnMl2zOTRPtO"
}
},
{
"cell_type": "code",
"source": [
"def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:\n",
" # Multiply accuracy of each client by number of examples used\n",
" accuracies = [num_examples * m[\"accuracy\"] for num_examples, m in metrics]\n",
" examples = [num_examples for num_examples, _ in metrics]\n",
"\n",
" # Aggregate and return custom metric (weighted average)\n",
" return {\"accuracy\": sum(accuracies) / sum(examples)}"
],
"metadata": {
"id": "rNTUJd4nRO56"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Create a Strategy\n",
"\n",
"Pick a strategy used for training. A good starting point is `FedAvg` but feel free to go throught the available strategies https://github.com/adap/flower/tree/main/src/py/flwr/server/strategy"
],
"metadata": {
"id": "HABHoeyq6E1m"
}
},
{
"cell_type": "code",
"source": [
"# Instantiate/Create a Flower strategy e.g. FedAvg\n",
"#TODO: Choose the strategy and specify the arguments\n",
"strategy = fl.server.strategy.FedAvg(\n",
" fraction_fit=1.0, # Sample 100% of available clients for training\n",
" fraction_evaluate=0.5, # Sample 50% of available clients for evaluation\n",
" min_fit_clients=10, # Never sample less than 10 clients for training\n",
" min_evaluate_clients=5, # Never sample less than 5 clients for evaluation\n",
" min_available_clients=10, # Wait until all 10 clients are available\n",
" evaluate_metrics_aggregation_fn=weighted_average,\n",
")"
],
"metadata": {
"id": "EERIycCO6eLA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "TFWQwL_41CGC"
},
"source": [
"### Run Flower Simulation\n",
"\n",
"The function `flwr.simulation.start_simulation` accepts a number of arguments, amongst them the `client_fn` used to create `FlowerClient` instances, the number of clients to simulate (`num_clients`), the number of federated learning rounds (`num_rounds`), and the strategy. The strategy encapsulates the federated learning approach/algorithm, for example, *Federated Averaging* (FedAvg).\n",
"\n",
"Flower has a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. The last step is the actual call to `start_simulation` which - you guessed it - starts the simulation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xBmL19b-1CGC",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "3651905a-9ec5-4bbf-928d-b4577adee3ae"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO flwr 2023-05-31 11:14:27,876 | app.py:146 | Starting Flower simulation, config: ServerConfig(num_rounds=2, round_timeout=None)\n",
"INFO:flwr:Starting Flower simulation, config: ServerConfig(num_rounds=2, round_timeout=None)\n",
"2023-05-31 11:14:30,092\tINFO worker.py:1625 -- Started a local Ray instance.\n",
"INFO flwr 2023-05-31 11:14:31,517 | app.py:180 | Flower VCE: Ray initialized with resources: {'object_store_memory': 3897945292.0, 'memory': 7795890587.0, 'node:172.28.0.12': 1.0, 'CPU': 2.0}\n",
"INFO:flwr:Flower VCE: Ray initialized with resources: {'object_store_memory': 3897945292.0, 'memory': 7795890587.0, 'node:172.28.0.12': 1.0, 'CPU': 2.0}\n",
"INFO flwr 2023-05-31 11:14:31,529 | server.py:86 | Initializing global parameters\n",
"INFO:flwr:Initializing global parameters\n",
"INFO flwr 2023-05-31 11:14:31,530 | server.py:273 | Requesting initial parameters from one random client\n",
"INFO:flwr:Requesting initial parameters from one random client\n",
"\u001b[2m\u001b[36m(pid=2351)\u001b[0m 2023-05-31 11:14:32.953457: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"INFO flwr 2023-05-31 11:14:35,998 | server.py:277 | Received initial parameters from one random client\n",
"INFO:flwr:Received initial parameters from one random client\n",
"INFO flwr 2023-05-31 11:14:36,002 | server.py:88 | Evaluating initial parameters\n",
"INFO:flwr:Evaluating initial parameters\n",
"INFO flwr 2023-05-31 11:14:36,005 | server.py:101 | FL starting\n",
"INFO:flwr:FL starting\n",
"DEBUG flwr 2023-05-31 11:14:36,008 | server.py:218 | fit_round 1: strategy sampled 10 clients (out of 10)\n",
"DEBUG:flwr:fit_round 1: strategy sampled 10 clients (out of 10)\n",
"DEBUG flwr 2023-05-31 11:17:15,276 | server.py:232 | fit_round 1 received 10 results and 0 failures\n",
"DEBUG:flwr:fit_round 1 received 10 results and 0 failures\n",
"WARNING flwr 2023-05-31 11:17:15,479 | fedavg.py:243 | No fit_metrics_aggregation_fn provided\n",
"WARNING:flwr:No fit_metrics_aggregation_fn provided\n",
"DEBUG flwr 2023-05-31 11:17:15,481 | server.py:168 | evaluate_round 1: strategy sampled 5 clients (out of 10)\n",
"DEBUG:flwr:evaluate_round 1: strategy sampled 5 clients (out of 10)\n",
"DEBUG flwr 2023-05-31 11:17:20,878 | server.py:182 | evaluate_round 1 received 5 results and 0 failures\n",
"DEBUG:flwr:evaluate_round 1 received 5 results and 0 failures\n",
"DEBUG flwr 2023-05-31 11:17:20,883 | server.py:218 | fit_round 2: strategy sampled 10 clients (out of 10)\n",
"DEBUG:flwr:fit_round 2: strategy sampled 10 clients (out of 10)\n",
"DEBUG flwr 2023-05-31 11:20:04,678 | server.py:232 | fit_round 2 received 10 results and 0 failures\n",
"DEBUG:flwr:fit_round 2 received 10 results and 0 failures\n",
"DEBUG flwr 2023-05-31 11:20:04,822 | server.py:168 | evaluate_round 2: strategy sampled 5 clients (out of 10)\n",
"DEBUG:flwr:evaluate_round 2: strategy sampled 5 clients (out of 10)\n",
"DEBUG flwr 2023-05-31 11:20:10,450 | server.py:182 | evaluate_round 2 received 5 results and 0 failures\n",
"DEBUG:flwr:evaluate_round 2 received 5 results and 0 failures\n",
"INFO flwr 2023-05-31 11:20:10,453 | server.py:147 | FL finished in 334.44558042200003\n",
"INFO:flwr:FL finished in 334.44558042200003\n",
"INFO flwr 2023-05-31 11:20:10,459 | app.py:218 | app_fit: losses_distributed [(1, 0.05944233487049739), (2, 0.016542544066905977)]\n",
"INFO:flwr:app_fit: losses_distributed [(1, 0.05944233487049739), (2, 0.016542544066905977)]\n",
"INFO flwr 2023-05-31 11:20:10,466 | app.py:219 | app_fit: metrics_distributed_fit {}\n",
"INFO:flwr:app_fit: metrics_distributed_fit {}\n",
"INFO flwr 2023-05-31 11:20:10,471 | app.py:220 | app_fit: metrics_distributed {'accuracy': [(1, 0.5691666666666667), (2, 0.794)]}\n",
"INFO:flwr:app_fit: metrics_distributed {'accuracy': [(1, 0.5691666666666667), (2, 0.794)]}\n",
"INFO flwr 2023-05-31 11:20:10,472 | app.py:221 | app_fit: losses_centralized []\n",
"INFO:flwr:app_fit: losses_centralized []\n",
"INFO flwr 2023-05-31 11:20:10,477 | app.py:222 | app_fit: metrics_centralized {}\n",
"INFO:flwr:app_fit: metrics_centralized {}\n"
]
}
],
"source": [
"client_resources = {\"num_cpus\": 2}\n",
"if DEVICE.type == \"cuda\":\n",
" client_resources[\"num_gpus\"] = 1\n",
"\n",
"\n",
"# Start simulation\n",
"history = fl.simulation.start_simulation(\n",
" client_fn=create_client_fn,\n",
" num_clients=NUM_CLIENTS,\n",
" config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),\n",
" strategy=strategy,\n",
" client_resources=client_resources,\n",
")"
]
},
{
"cell_type": "code",
"source": [
"history"
],
"metadata": {
"id": "KfMraHRcTKXu",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7f8d6d0f-4de3-47d8-fc83-b83408dec014"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"History (loss, distributed):\n",
"\tround 1: 0.05944233487049739\n",
"\tround 2: 0.016542544066905977\n",
"History (metrics, distributed, evaluate):\n",
"{'accuracy': [(1, 0.5691666666666667), (2, 0.794)]}"
]
},
"metadata": {},
"execution_count": 21
}
]
},
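{
"cell_type": "markdown",
"source": [
"The distributed accuracy reported above is an aggregate of the results returned by the sampled clients. As a minimal sketch of how such an aggregation can work, the cell below computes an example-weighted average, so clients with more examples pull the result toward their own accuracy. It assumes each client reports a `(num_examples, metrics)` pair; the name `weighted_accuracy` and the sample numbers are hypothetical and are not taken from this notebook's own `weighted_average`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"```python\n",
"from typing import Dict, List, Tuple\n",
"\n",
"\n",
"def weighted_accuracy(results: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:\n",
"    \"\"\"Average client accuracies, weighting each by its number of examples.\"\"\"\n",
"    total_examples = sum(num_examples for num_examples, _ in results)\n",
"    weighted_sum = sum(\n",
"        num_examples * metrics[\"accuracy\"] for num_examples, metrics in results\n",
"    )\n",
"    return {\"accuracy\": weighted_sum / total_examples}\n",
"\n",
"\n",
"# Hypothetical per-client results: (num_examples, metrics)\n",
"example_results = [(100, {\"accuracy\": 0.5}), (300, {\"accuracy\": 0.9})]\n",
"print(weighted_accuracy(example_results))\n",
"```"
],
"metadata": {},
"execution_count": null,
"outputs": []
},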
{
"cell_type": "markdown",
"source": [
"## Centralized Evaluation\n",
"\n",
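"Centralized evaluation runs on the server rather than on the clients. To enable it, pass an `evaluate_fn` to the strategy; the server calls it on the initial parameters and again after each training round."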
],
"metadata": {
"id": "oyRS7Hey9c7u"
}
},
{
"cell_type": "code",
"source": [
"def evaluate(\n",
" server_round: int,\n",
" parameters: fl.common.NDArrays,\n",
" config: Dict[str, fl.common.Scalar],\n",
" ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:\n",
" \"\"\"Centralized evaluation function\"\"\"\n",
" net = Net().to(DEVICE)\n",
" set_parameters(net, parameters)\n",
" loss, accuracy = test(net, centralized_loader)\n",
" print(f\"Server-side evaluation loss {loss} / accuracy {accuracy}\")\n",
" return loss, {\"accuracy\": accuracy}\n",
"\n",
"# Strategy: FedAvg with server-side evaluation enabled through evaluate_fn\n",
"strategy = fl.server.strategy.FedAvg(\n",
" fraction_fit=1.0, # Sample 100% of available clients for training\n",
" fraction_evaluate=0.5, # Sample 50% of available clients for evaluation\n",
" min_fit_clients=10, # Never sample fewer than 10 clients for training\n",
" min_evaluate_clients=5, # Never sample fewer than 5 clients for evaluation\n",
" min_available_clients=10, # Wait until all 10 clients are available\n",
" evaluate_metrics_aggregation_fn=weighted_average,\n",
" evaluate_fn=evaluate\n",
")\n",
"\n",
"# Start simulation\n",
"history = fl.simulation.start_simulation(\n",
" client_fn=create_client_fn,\n",
" num_clients=NUM_CLIENTS,\n",
" config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),\n",
" strategy=strategy,\n",
" client_resources=client_resources,\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "chLOL6VP9hXu",
"outputId": "b0fb0fc2-d12e-4ec7-91b3-ea4251a30e75"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO flwr 2023-05-31 11:20:10,546 | app.py:146 | Starting Flower simulation, config: ServerConfig(num_rounds=2, round_timeout=None)\n",
"INFO:flwr:Starting Flower simulation, config: ServerConfig(num_rounds=2, round_timeout=None)\n",
"2023-05-31 11:20:14,979\tINFO worker.py:1625 -- Started a local Ray instance.\n",
"INFO flwr 2023-05-31 11:20:16,383 | app.py:180 | Flower VCE: Ray initialized with resources: {'node:172.28.0.12': 1.0, 'memory': 7804784640.0, 'CPU': 2.0, 'object_store_memory': 3902392320.0}\n",
"INFO:flwr:Flower VCE: Ray initialized with resources: {'node:172.28.0.12': 1.0, 'memory': 7804784640.0, 'CPU': 2.0, 'object_store_memory': 3902392320.0}\n",
"INFO flwr 2023-05-31 11:20:16,391 | server.py:86 | Initializing global parameters\n",
"INFO:flwr:Initializing global parameters\n",
"INFO flwr 2023-05-31 11:20:16,393 | server.py:273 | Requesting initial parameters from one random client\n",
"INFO:flwr:Requesting initial parameters from one random client\n",
"\u001b[2m\u001b[36m(pid=4070)\u001b[0m 2023-05-31 11:20:17.849262: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"INFO flwr 2023-05-31 11:20:20,947 | server.py:277 | Received initial parameters from one random client\n",
"INFO:flwr:Received initial parameters from one random client\n",
"INFO flwr 2023-05-31 11:20:20,951 | server.py:88 | Evaluating initial parameters\n",
"INFO:flwr:Evaluating initial parameters\n",
"INFO flwr 2023-05-31 11:20:29,433 | server.py:91 | initial parameters (loss, other metrics): 0.07208683090209961, {'accuracy': 0.0917}\n",
"INFO:flwr:initial parameters (loss, other metrics): 0.07208683090209961, {'accuracy': 0.0917}\n",
"INFO flwr 2023-05-31 11:20:29,436 | server.py:101 | FL starting\n",
"INFO:flwr:FL starting\n",
"DEBUG flwr 2023-05-31 11:20:29,439 | server.py:218 | fit_round 1: strategy sampled 10 clients (out of 10)\n",
"DEBUG:flwr:fit_round 1: strategy sampled 10 clients (out of 10)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Server-side evaluation loss 0.07208683090209961 / accuracy 0.0917\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"DEBUG flwr 2023-05-31 11:23:06,412 | server.py:232 | fit_round 1 received 10 results and 0 failures\n",
"DEBUG:flwr:fit_round 1 received 10 results and 0 failures\n",
"WARNING flwr 2023-05-31 11:23:06,569 | fedavg.py:243 | No fit_metrics_aggregation_fn provided\n",
"WARNING:flwr:No fit_metrics_aggregation_fn provided\n",
"INFO flwr 2023-05-31 11:23:14,960 | server.py:119 | fit progress: (1, 0.06729964368343354, {'accuracy': 0.6271}, 165.52080230600006)\n",
"INFO:flwr:fit progress: (1, 0.06729964368343354, {'accuracy': 0.6271}, 165.52080230600006)\n",
"DEBUG flwr 2023-05-31 11:23:14,968 | server.py:168 | evaluate_round 1: strategy sampled 5 clients (out of 10)\n",
"DEBUG:flwr:evaluate_round 1: strategy sampled 5 clients (out of 10)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Server-side evaluation loss 0.06729964368343354 / accuracy 0.6271\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"DEBUG flwr 2023-05-31 11:23:20,438 | server.py:182 | evaluate_round 1 received 5 results and 0 failures\n",
"DEBUG:flwr:evaluate_round 1 received 5 results and 0 failures\n",
"DEBUG flwr 2023-05-31 11:23:20,443 | server.py:218 | fit_round 2: strategy sampled 10 clients (out of 10)\n",
"DEBUG:flwr:fit_round 2: strategy sampled 10 clients (out of 10)\n",
"DEBUG flwr 2023-05-31 11:26:00,031 | server.py:232 | fit_round 2 received 10 results and 0 failures\n",
"DEBUG:flwr:fit_round 2 received 10 results and 0 failures\n",
"INFO flwr 2023-05-31 11:26:07,903 | server.py:119 | fit progress: (2, 0.009333141782600432, {'accuracy': 0.9676}, 338.4640598540001)\n",
"INFO:flwr:fit progress: (2, 0.009333141782600432, {'accuracy': 0.9676}, 338.4640598540001)\n",
"DEBUG flwr 2023-05-31 11:26:07,907 | server.py:168 | evaluate_round 2: strategy sampled 5 clients (out of 10)\n",
"DEBUG:flwr:evaluate_round 2: strategy sampled 5 clients (out of 10)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Server-side evaluation loss 0.009333141782600432 / accuracy 0.9676\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"DEBUG flwr 2023-05-31 11:26:13,931 | server.py:182 | evaluate_round 2 received 5 results and 0 failures\n",
"DEBUG:flwr:evaluate_round 2 received 5 results and 0 failures\n",
"INFO flwr 2023-05-31 11:26:13,934 | server.py:147 | FL finished in 344.49554322000006\n",
"INFO:flwr:FL finished in 344.49554322000006\n",
"INFO flwr 2023-05-31 11:26:13,950 | app.py:218 | app_fit: losses_distributed [(1, 0.06823034922281901), (2, 0.009110654158207278)]\n",
"INFO:flwr:app_fit: losses_distributed [(1, 0.06823034922281901), (2, 0.009110654158207278)]\n",
"INFO flwr 2023-05-31 11:26:13,955 | app.py:219 | app_fit: metrics_distributed_fit {}\n",
"INFO:flwr:app_fit: metrics_distributed_fit {}\n",
"INFO flwr 2023-05-31 11:26:13,961 | app.py:220 | app_fit: metrics_distributed {'accuracy': [(1, 0.6213333333333334), (2, 0.9670000000000001)]}\n",
"INFO:flwr:app_fit: metrics_distributed {'accuracy': [(1, 0.6213333333333334), (2, 0.9670000000000001)]}\n",
"INFO flwr 2023-05-31 11:26:13,964 | app.py:221 | app_fit: losses_centralized [(0, 0.07208683090209961), (1, 0.06729964368343354), (2, 0.009333141782600432)]\n",
"INFO:flwr:app_fit: losses_centralized [(0, 0.07208683090209961), (1, 0.06729964368343354), (2, 0.009333141782600432)]\n",
"INFO flwr 2023-05-31 11:26:13,965 | app.py:222 | app_fit: metrics_centralized {'accuracy': [(0, 0.0917), (1, 0.6271), (2, 0.9676)]}\n",
"INFO:flwr:app_fit: metrics_centralized {'accuracy': [(0, 0.0917), (1, 0.6271), (2, 0.9676)]}\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"history"
],
"metadata": {
"id": "1WMIs9lvE_-y",
"outputId": "e652064b-83b5-446c-abd9-6c1e8318ec5c",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"History (loss, distributed):\n",
"\tround 1: 0.06823034922281901\n",
"\tround 2: 0.009110654158207278\n",
"History (loss, centralized):\n",
"\tround 0: 0.07208683090209961\n",
"\tround 1: 0.06729964368343354\n",
"\tround 2: 0.009333141782600432\n",
"History (metrics, distributed, evaluate):\n",
"{'accuracy': [(1, 0.6213333333333334), (2, 0.9670000000000001)]}History (metrics, centralized):\n",
"{'accuracy': [(0, 0.0917), (1, 0.6271), (2, 0.9676)]}"
]
},
"metadata": {},
"execution_count": 23
}
]
},
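{
"cell_type": "markdown",
"source": [
"The two accuracy histories printed above can be compared round by round. The cell below is a small sketch that uses values copied from that output; in a live session the same dictionaries should be available as `history.metrics_distributed` and `history.metrics_centralized`. The helper name `compare_accuracy` is hypothetical. Round 0 appears only in the centralized history (the evaluation of the initial parameters), so it is skipped."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"```python\n",
"from typing import Dict, List, Tuple\n",
"\n",
"# Values copied from the History output above\n",
"metrics_distributed = {\"accuracy\": [(1, 0.6213333333333334), (2, 0.9670000000000001)]}\n",
"metrics_centralized = {\"accuracy\": [(0, 0.0917), (1, 0.6271), (2, 0.9676)]}\n",
"\n",
"\n",
"def compare_accuracy(\n",
"    distributed: Dict[str, List[Tuple[int, float]]],\n",
"    centralized: Dict[str, List[Tuple[int, float]]],\n",
") -> List[Tuple[int, float, float]]:\n",
"    \"\"\"Pair distributed and centralized accuracy for rounds present in both.\"\"\"\n",
"    central_by_round = dict(centralized[\"accuracy\"])\n",
"    return [\n",
"        (rnd, acc, central_by_round[rnd])\n",
"        for rnd, acc in distributed[\"accuracy\"]\n",
"        if rnd in central_by_round\n",
"    ]\n",
"\n",
"\n",
"for rnd, dist_acc, cent_acc in compare_accuracy(metrics_distributed, metrics_centralized):\n",
"    print(f\"round {rnd}: distributed={dist_acc:.4f} centralized={cent_acc:.4f}\")\n",
"```"
],
"metadata": {},
"execution_count": null,
"outputs": []
},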
{
"cell_type": "markdown",
"metadata": {
"id": "jyex7uv31CGG"
},
"source": [
"## Final remarks\n",
"\n",
"Congratulations, you just trained a convolutional neural network, federated over 10 clients! With that, you understand the basics of federated learning with Flower. The same approach can be used with other machine learning frameworks (not just PyTorch) and tasks (not just MNIST image classification), for example, NLP with Hugging Face Transformers or speech with SpeechBrain."
]
}
],
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "flower-3.7.12",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}