Skip to content

Instantly share code, notes, and snippets.

@kirarpit
Created March 15, 2022 20:01
Show Gist options
  • Save kirarpit/e7b18200a7540f790c8926266e079d32 to your computer and use it in GitHub Desktop.
Save kirarpit/e7b18200a7540f790c8926266e079d32 to your computer and use it in GitHub Desktop.
[SHARED] Central vs Local EMNIST
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/kirarpit/e7b18200a7540f790c8926266e079d32/-shared-central-vs-local-emnist.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"## Before we start\n",
"\n",
"Before we start, please run the following to make sure that your environment is\n",
"correctly set up. If you don't see a greeting, please refer to the\n",
"[Installation](../install.md) guide for instructions. "
],
"metadata": {
"id": "h8SiGeczTT1p"
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fyYJeBKZXnRh",
"outputId": "c728710a-a494-406f-c054-297c03d36404"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[?25l\r\u001b[K |▍ | 10 kB 18.0 MB/s eta 0:00:01\r\u001b[K |▉ | 20 kB 18.1 MB/s eta 0:00:01\r\u001b[K |█▏ | 30 kB 7.4 MB/s eta 0:00:01\r\u001b[K |█▋ | 40 kB 3.9 MB/s eta 0:00:01\r\u001b[K |██ | 51 kB 3.8 MB/s eta 0:00:01\r\u001b[K |██▍ | 61 kB 4.6 MB/s eta 0:00:01\r\u001b[K |██▉ | 71 kB 4.6 MB/s eta 0:00:01\r\u001b[K |███▏ | 81 kB 4.7 MB/s eta 0:00:01\r\u001b[K |███▋ | 92 kB 5.3 MB/s eta 0:00:01\r\u001b[K |████ | 102 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████▍ | 112 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████▉ | 122 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████▏ | 133 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████▋ | 143 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████ | 153 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████▍ | 163 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████▉ | 174 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████▏ | 184 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████▋ | 194 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████ | 204 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████▍ | 215 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████▉ | 225 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████▏ | 235 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████▋ | 245 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████ | 256 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████▍ | 266 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████▉ | 276 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████▏ | 286 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████▋ | 296 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████ | 307 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████▍ | 317 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████▉ | 327 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████████▏ | 337 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████████▋ | 348 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████ | 358 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████▍ | 368 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████▉ | 378 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████████▏ | 389 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████████▋ | 399 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████ 
| 409 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████▍ | 419 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████▉ | 430 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████████████▏ | 440 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████████████▋ | 450 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████ | 460 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████▍ | 471 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████▉ | 481 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████████████▏ | 491 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████████████▋ | 501 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████ | 512 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████▍ | 522 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████▉ | 532 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████████████████▏ | 542 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████████████████▋ | 552 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 563 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████████▍ | 573 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████████▉ | 583 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████████████████▏ | 593 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████████████████▋ | 604 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████████ | 614 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████████▍ | 624 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████████▉ | 634 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████████████████████▏ | 645 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████████████████████▋ | 655 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████████████ | 665 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████████████▍ | 675 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████████████▊ | 686 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████████████████████▏ | 696 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████████████████████▋ | 706 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████████████ | 716 kB 4.3 MB/s eta 0:00:01\r\u001b[K 
|████████████████████████████▍ | 727 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████████████▊ | 737 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▏ | 747 kB 4.3 MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▋ | 757 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████ | 768 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▍ | 778 kB 4.3 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▊ | 788 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▏| 798 kB 4.3 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▋| 808 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 819 kB 4.3 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 819 kB 4.3 MB/s \n",
"\u001b[K |████████████████████████████████| 234 kB 46.6 MB/s \n",
"\u001b[K |████████████████████████████████| 53 kB 322 kB/s \n",
"\u001b[K |████████████████████████████████| 887 kB 17.5 MB/s \n",
"\u001b[K |████████████████████████████████| 4.0 MB 23.6 MB/s \n",
"\u001b[K |████████████████████████████████| 121 kB 43.4 MB/s \n",
"\u001b[K |████████████████████████████████| 65.1 MB 80 kB/s \n",
"\u001b[K |████████████████████████████████| 45 kB 2.3 MB/s \n",
"\u001b[K |████████████████████████████████| 251 kB 39.4 MB/s \n",
"\u001b[K |████████████████████████████████| 462 kB 35.7 MB/s \n",
"\u001b[K |████████████████████████████████| 4.2 MB 35.1 MB/s \n",
"\u001b[?25h Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"spacy 2.2.4 requires tqdm<5.0.0,>=4.38.0, but you have tqdm 4.28.1 which is incompatible.\n",
"pymc3 3.11.4 requires cachetools>=4.2.1, but you have cachetools 3.1.1 which is incompatible.\n",
"panel 0.12.1 requires tqdm>=4.48.0, but you have tqdm 4.28.1 which is incompatible.\n",
"fbprophet 0.7.1 requires tqdm>=4.36.1, but you have tqdm 4.28.1 which is incompatible.\n",
"datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.\u001b[0m\n"
]
}
],
"source": [
"# Install with `%pip` (not `!pip`) so the packages land in the environment of\n",
"# the running kernel.\n",
"# NOTE(review): later cells call `tff.learning.build_federated_averaging_process`;\n",
"# presumably this needs a TFF release from around the time the notebook was\n",
"# written -- consider pinning `tensorflow-federated==<version>` once confirmed.\n",
"%pip install --quiet --upgrade tensorflow-federated\n",
"%pip install --quiet --upgrade nest-asyncio\n",
"\n",
"# Patch asyncio so event loops can be nested inside the notebook's own loop.\n",
"import nest_asyncio\n",
"nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "ZpTgt9YwXxuS"
},
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "AaOqWEWKXuua",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "8a410c48-332c-4e29-a1d1-5827fbb41378"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/gdrive\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"b'Hello, World!'"
]
},
"metadata": {},
"execution_count": 3
}
],
"source": [
"import collections\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_federated as tff\n",
"import matplotlib as mpl\n",
"from matplotlib import pyplot as plt\n",
"import pickle\n",
"import os\n",
"from google.colab import drive\n",
"# Mount Google Drive under /content/gdrive.\n",
"drive.mount('/content/gdrive')\n",
"\n",
"# Global matplotlib font settings for the figures in this notebook.\n",
"font = dict(family=\"DejaVu Sans\", weight=\"bold\", size=18)\n",
"mpl.rc(\"font\", **font)\n",
"\n",
"# Seed NumPy's global random number generator.\n",
"np.random.seed(0)\n",
"\n",
"\n",
"# Smoke test: a trivial federated computation whose result is shown as the\n",
"# cell output.\n",
"@tff.federated_computation\n",
"def hello_world():\n",
"  return 'Hello, World!'\n",
"\n",
"\n",
"hello_world()"
]
},
{
"cell_type": "markdown",
"source": [
"## Understanding the data\n",
"\n",
"Let's start with the data. Federated learning requires a federated data set,\n",
"i.e., a collection of data from multiple users. Federated data is typically\n",
"non-[i.i.d.](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables),\n",
"which poses a unique set of challenges.\n",
"\n",
"In order to facilitate experimentation, we seeded the TFF repository with a few\n",
"datasets, including a federated version of MNIST that contains a version of the [original NIST dataset](https://www.nist.gov/srd/nist-special-database-19) that has been re-processed using [Leaf](https://github.com/TalwalkarLab/leaf) so that the data is keyed by the original writer of the digits. Since each writer has a unique style, this dataset exhibits the kind of non-i.i.d. behavior expected of federated datasets.\n",
"\n",
"Here's how we can load it."
],
"metadata": {
"id": "tpXABKLdTeXo"
}
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "va9rH4rEX1-Z",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "0693bb15-ec16-468d-c232-f5df4d4fc85b"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [01:00<00:00, 2985562.81it/s]\n"
]
}
],
"source": [
"emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "znn3RlnlX3Pa",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 248
},
"outputId": "5688de0b-ddb0-40e6-e09e-28b486508343"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1440x288 with 40 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"# Raw (un-preprocessed) examples of a single client.\n",
"example_dataset = emnist_train.create_tf_dataset_for_client(\n",
"    emnist_train.client_ids[2])\n",
"\n",
"# Draw up to 40 of that client's digits on a 4 x 10 grid.\n",
"figure = plt.figure(figsize=(20, 4))\n",
"for position, example in enumerate(example_dataset.take(40), start=1):\n",
"  plt.subplot(4, 10, position)\n",
"  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')\n",
"  plt.axis('off')"
]
},
{
"cell_type": "markdown",
"source": [
"### Mean of each pixel value for all of the user's examples per label\n",
"Let's visualize the mean image per client for each MNIST label. This code will produce the mean of each pixel value for all of the user's examples for one label. We'll see that one client's mean image for a digit will look different than another client's mean image for the same digit, due to each person's unique handwriting style. We can muse about how each local training round will nudge the model in a different direction on each client, as we're learning from that user's own unique data in that local round. Later in the tutorial we'll see how we can take each update to the model from all the clients and aggregate them together into our new global model, which has learned from each of our clients' own unique data."
],
"metadata": {
"id": "0CHvUABGWZaH"
}
},
{
"cell_type": "code",
"source": [
"# Each client has different mean images, meaning each client will be nudging\n",
"# the model in its own direction locally.\n",
"\n",
"for i in [0, 2, 3, 9]:\n",
"  client_dataset = emnist_train.create_tf_dataset_for_client(\n",
"      emnist_train.client_ids[i])\n",
"\n",
"  # Group this client's images by label.\n",
"  plot_data = collections.defaultdict(list)\n",
"  for example in client_dataset:\n",
"    plot_data[example['label'].numpy()].append(example['pixels'].numpy())\n",
"\n",
"  f = plt.figure(i, figsize=(12, 5))\n",
"  f.suptitle(\"Client #{}'s Mean Image Per Label\".format(i))\n",
"  for j in range(10):\n",
"    plt.subplot(2, 5, j+1)\n",
"    images = plot_data.get(j)\n",
"    if images:\n",
"      mean_img = np.mean(images, 0)\n",
"      plt.imshow(mean_img.reshape((28, 28)))\n",
"    # else: this client has no examples of digit j. np.mean of an empty list\n",
"    # is a scalar NaN that cannot be reshaped to 28x28, so the panel is left\n",
"    # blank instead of raising.\n",
"    plt.axis('off')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "47J6JVzAWdXH",
"outputId": "41535f43-bb2e-44e1-809f-251c6570e773"
},
"execution_count": 6,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 864x360 with 10 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 864x360 with 10 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 864x360 with 10 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 864x360 with 10 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Data pre-processing and model definition\n",
"In this block we define the preprocessing function, load the federated and central datasets, and define the model function."
],
"metadata": {
"id": "F04fW8PbYUbW"
}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "YaPRjvpsYhmq"
},
"outputs": [],
"source": [
"NUM_CLIENTS = 10\n",
"BATCH_SIZE = 20\n",
"SHUFFLE_BUFFER = 100\n",
"PREFETCH_BUFFER = 10\n",
"\n",
"def preprocess(dataset):\n",
"  \"\"\"Shuffles, batches and flattens a raw EMNIST dataset.\n",
"\n",
"  Args:\n",
"    dataset: dataset whose elements expose `'pixels'` and `'label'` entries.\n",
"\n",
"  Returns:\n",
"    The dataset shuffled with `SHUFFLE_BUFFER` (seed 1), batched by\n",
"    `BATCH_SIZE`, mapped to `OrderedDict(x=..., y=...)` and prefetched with\n",
"    `PREFETCH_BUFFER`.\n",
"  \"\"\"\n",
"\n",
"  def batch_format_fn(element):\n",
"    \"\"\"Reshapes a batch: `pixels` to [-1, 784] as `x`, `label` to [-1, 1] as `y`.\"\"\"\n",
"    flat_pixels = tf.reshape(element['pixels'], [-1, 784])\n",
"    column_labels = tf.reshape(element['label'], [-1, 1])\n",
"    return collections.OrderedDict(x=flat_pixels, y=column_labels)\n",
"\n",
"  shuffled = dataset.shuffle(SHUFFLE_BUFFER, seed=1)\n",
"  batched = shuffled.batch(BATCH_SIZE)\n",
"  formatted = batched.map(batch_format_fn)\n",
"  return formatted.prefetch(PREFETCH_BUFFER)\n",
"\n",
"def make_federated_data(client_data, client_ids):\n",
"  \"\"\"Returns a list with one preprocessed dataset per id in `client_ids`, in order.\"\"\"\n",
"  federated_data = []\n",
"  for client_id in client_ids:\n",
"    raw_dataset = client_data.create_tf_dataset_for_client(client_id)\n",
"    federated_data.append(preprocess(raw_dataset))\n",
"  return federated_data\n",
"\n",
"def make_central_data(client_data, client_ids):\n",
"  \"\"\"Pools the listed clients into one preprocessed dataset.\n",
"\n",
"  The clients' raw datasets are concatenated in the order of `client_ids`\n",
"  and the concatenation is then passed through `preprocess`. An empty\n",
"  `client_ids` raises `IndexError`.\n",
"  \"\"\"\n",
"  ids = list(client_ids)\n",
"  pooled = client_data.create_tf_dataset_for_client(ids[0])\n",
"  for client_id in ids[1:]:\n",
"    pooled = pooled.concatenate(\n",
"        client_data.create_tf_dataset_for_client(client_id))\n",
"  return preprocess(pooled)\n",
"\n",
"def create_keras_model():\n",
"  \"\"\"Builds the classifier: 784 inputs -> Dense(10, zero-initialised kernel) -> Softmax.\"\"\"\n",
"  model = tf.keras.models.Sequential()\n",
"  model.add(tf.keras.layers.InputLayer(input_shape=(784,)))\n",
"  model.add(tf.keras.layers.Dense(10, kernel_initializer='zeros'))\n",
"  model.add(tf.keras.layers.Softmax())\n",
"  return model\n",
"\n",
"# Both the federated and the central setup use the first NUM_CLIENTS clients.\n",
"sampled_clients = emnist_train.client_ids[:NUM_CLIENTS]\n",
"federated_train_data = make_federated_data(\n",
"    client_data=emnist_train, client_ids=sampled_clients)\n",
"federated_test_data = make_federated_data(\n",
"    client_data=emnist_test, client_ids=sampled_clients)\n",
"central_train_data = make_central_data(\n",
"    client_data=emnist_train, client_ids=sampled_clients)\n",
"central_test_data = make_central_data(\n",
"    client_data=emnist_test, client_ids=sampled_clients)\n",
"\n",
"def model_fn():\n",
"  \"\"\"Returns a new TFF model wrapping a freshly created Keras classifier.\n",
"\n",
"  The Keras model is built inside this function on every call and is never\n",
"  captured from an outer scope, because TFF calls `model_fn` within\n",
"  different graph contexts.\n",
"  \"\"\"\n",
"  return tff.learning.from_keras_model(\n",
"      keras_model=create_keras_model(),\n",
"      input_spec=federated_train_data[0].element_spec,\n",
"      loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
"      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])"
]
},
{
"cell_type": "markdown",
"source": [
"## Defining federated averaging iterative process"
],
"metadata": {
"id": "Xtx6u0pPaAaU"
}
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "g0XvkRs8Y3ad"
},
"outputs": [],
"source": [
"EXP_ID = 4\n",
"NUM_ROUNDS = 500\n",
"\n",
"def _client_sgd():\n",
"  \"\"\"Optimizer factory passed as `client_optimizer_fn`.\"\"\"\n",
"  return tf.keras.optimizers.SGD(learning_rate=0.02)\n",
"\n",
"\n",
"def _server_sgd():\n",
"  \"\"\"Optimizer factory passed as `server_optimizer_fn`.\"\"\"\n",
"  return tf.keras.optimizers.SGD(learning_rate=1.0)\n",
"\n",
"\n",
"iterative_process = tff.learning.build_federated_averaging_process(\n",
"    model_fn,\n",
"    client_optimizer_fn=_client_sgd,\n",
"    server_optimizer_fn=_server_sgd)\n",
"evaluation = tff.learning.build_federated_evaluation(model_fn)\n",
"\n",
"def log_metrics(train_metric_dict, test_metric_dict, step=None):\n",
"  \"\"\"Writes train/test loss and accuracy as `tf.summary` scalars.\n",
"\n",
"  Args:\n",
"    train_metric_dict: training metrics; `['train']['loss']` and\n",
"      `['train']['sparse_categorical_accuracy']` are read.\n",
"    test_metric_dict: evaluation metrics; `['eval']['loss']` and\n",
"      `['eval']['sparse_categorical_accuracy']` are read.\n",
"    step: summary step to log at. When omitted, the notebook-global\n",
"      `round_num` is used, which is what callers passing only two arguments\n",
"      rely on.\n",
"  \"\"\"\n",
"  if step is None:\n",
"    # Backward-compatible fallback to the training loop's global counter.\n",
"    step = round_num\n",
"\n",
"  # Read every metric up front so a missing key fails before anything is written.\n",
"  scalars = [\n",
"      ('train_loss', train_metric_dict['train']['loss']),\n",
"      ('train_acc', train_metric_dict['train']['sparse_categorical_accuracy']),\n",
"      ('test_loss', test_metric_dict['eval']['loss']),\n",
"      ('test_acc', test_metric_dict['eval']['sparse_categorical_accuracy']),\n",
"  ]\n",
"  for name, value in scalars:\n",
"    tf.summary.scalar(name, value, step=step)\n"
]
},
{
"cell_type": "markdown",
"source": [
"## Training an FL model"
],
"metadata": {
"id": "ehQoqR2H9Dkk"
}
},
{
"cell_type": "code",
"source": [
"logdir = f\"/tmp/logs/scalars/training/fl_expid{EXP_ID}\"\n",
"summary_writer = tf.summary.create_file_writer(logdir)\n",
"\n",
"# Run NUM_ROUNDS of federated averaging; after each round, evaluate the\n",
"# updated model on the clients' test data and log both sets of metrics.\n",
"fl_model_state = iterative_process.initialize()\n",
"with summary_writer.as_default():\n",
"  for round_num in range(NUM_ROUNDS):\n",
"    fl_model_state, metrics = iterative_process.next(fl_model_state, federated_train_data)\n",
"    metric_dict = evaluation(fl_model_state.model, federated_test_data)\n",
"    print(f\"round #: {round_num}, metrics: {metric_dict}\")\n",
"    # `log_metrics` uses the loop variable `round_num` as the summary step.\n",
"    log_metrics(metrics, metric_dict)\n",
"\n",
"\n",
"# Persist the final server state to Drive. `os.makedirs(..., exist_ok=True)`\n",
"# also creates the parent `saved_states` folder when it is missing, where a\n",
"# plain `os.mkdir` would raise FileNotFoundError.\n",
"directory = f'/content/gdrive/My Drive/saved_states/{EXP_ID}'\n",
"os.makedirs(directory, exist_ok=True)\n",
"with open(os.path.join(directory, \"fl\"), \"wb\") as f:\n",
"  pickle.dump(fl_model_state, f)"
],
"metadata": {
"id": "xnGc4tHb9DNe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!ls {logdir}\n",
"%tensorboard --logdir {logdir} --port=0"
],
"metadata": {
"colab": {
"resources": {
"https://localhost:34825/?tensorboardColab=true": {
"data": "",
"ok": true,
"headers": [
[
"content-type",
"text/html; charset=utf-8"
]
],
"status": 200,
"status_text": ""
},
"https://localhost:34825/index.js?_file_hash=4a366cbe": {
"data": "
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment