@mbacvanski
Created December 14, 2023 16:49
Forward-Forward Algorithm with MNIST
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/mbacvanski/71909a86218400aa388832a6071c53df/forward_forward_mnist_with_explanations.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:04:09.192145Z",
"start_time": "2023-12-03T22:04:08.487888Z"
},
"id": "4TgCLCbPlP0b"
},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"import torch\n",
"from tqdm import tqdm\n",
"\n",
"DEVICE = 'cpu'\n",
"\n",
"metal = torch.device(DEVICE)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3d61H1h8ldmY"
},
"source": [
"The goal of the forward-forward algorithm is to train each layer individually, so that each layer responds \"positively\" to positive data and \"negatively\" to negative data. There are several ways to define a \"positive\" versus \"negative\" response; a simple choice is the magnitude of the vector output by the layer. This notebook measures it as the mean of the squared activations (the squared L2 norm divided by the layer width), which the code calls the layer's goodness.\n",
"\n",
"Let's define a single layer first. It is very similar to a traditional linear layer: the output is the activation function applied to an affine transformation of the input, $\\sigma(z)$, where $z=w^Tx+b$, with $w$ being the weights of this layer, $x$ the input, and $b$ the bias."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:13:00.193340Z",
"start_time": "2023-12-03T22:13:00.183319Z"
},
"id": "WhJrOyDVmiaC"
},
"outputs": [],
"source": [
"class Layer(torch.nn.Linear):\n",
"    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):\n",
"        super().__init__(in_features, out_features, bias, device, dtype)\n",
"\n",
"        # define the components we will use\n",
"        self.activation = torch.nn.ReLU()\n",
"        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)\n",
"\n",
"        # define hyperparameters\n",
"        self.threshold = 2.0\n",
"        self.num_epochs = 20\n",
"\n",
"        # keep track of losses during training\n",
"        self.losses = []\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"\n",
"        Here we define the forward pass through this layer. This is the same as how a linear layer does it,\n",
"        multiplying x by the transposed weights and adding the bias term. The difference is that here we\n",
"        integrate normalization into the layer, so we first reduce each input row to a unit vector.\n",
"\n",
"        Remember that the previous layer outputs a vector whose magnitude depends on how 'excited' it is about\n",
"        its input. Normalizing that vector guarantees that only its direction, not its magnitude, determines how\n",
"        excited this layer is about the input.\n",
"        \"\"\"\n",
"        x_normalized = x / (x.norm(2, 1, keepdim=True) + 1e-8)\n",
"        # x w^T + b\n",
"        return self.activation(\n",
"            torch.mm(x_normalized, self.weight.T) + self.bias.unsqueeze(0)\n",
"        )\n",
"\n",
"    def train(self, x_positive, x_negative, verbose=True):\n",
"        \"\"\"\n",
"        Each layer of the forward-forward network can train itself, because it does not rely on gradients from other\n",
"        layers in the network. All it needs is a source of positive and negative data. The first layer of the network\n",
"        receives data as the concatenation of X and Y (data and label), while subsequent layers receive the output\n",
"        of the previous layer's forward pass on the positive and negative data respectively.\n",
"        \"\"\"\n",
"        self.losses = []\n",
"        with tqdm(range(self.num_epochs), disable=not verbose) as pbar:\n",
"            for _ in pbar:\n",
"                goodness_positive = self.forward(x_positive).pow(2).mean(1)\n",
"                goodness_negative = self.forward(x_negative).pow(2).mean(1)\n",
"\n",
"                \"\"\"\n",
"                We want to minimize a loss built from a vector of two parts.\n",
"                The first part penalizes -(goodness - threshold) for positive data, which pushes goodness up for positive data.\n",
"                The second part penalizes (goodness - threshold) for negative data, which pushes goodness down for negative data.\n",
"                The threshold is the goodness level that should separate positive from negative data.\n",
"                Each entry passes through log(1 + exp(z)), a smooth version of max(0, z), before averaging.\n",
"                A dual formulation of this optimization problem may also be possible.\n",
"                \"\"\"\n",
"\n",
"                loss = torch.log(1 + torch.exp(torch.cat([\n",
"                    -goodness_positive + self.threshold,\n",
"                    goodness_negative - self.threshold]))).mean()\n",
"\n",
"                \"\"\"\n",
"                Notice that even though we use the gradient, this gradient is only locally computed within this layer.\n",
"                \"\"\"\n",
"                self.optimizer.zero_grad()\n",
"                loss.backward()\n",
"                self.optimizer.step()\n",
"\n",
"                self.losses.append(loss.item())\n",
"\n",
"                pbar.set_description(f'Loss: {loss.item():.4f}')\n",
"\n",
"        \"\"\"\n",
"        The activations computed in this layer's forward pass are needed to train the next layer. The next layer\n",
"        will use them as its x_positive and x_negative. They are detached so no gradient flows back into this layer.\n",
"        \"\"\"\n",
"        return self.forward(x_positive).detach(), self.forward(x_negative).detach()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n9pvVAk-5Whe"
},
"source": [
"Now that we've defined a single layer, let's combine some layers into a multi-layer network."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:09:12.933892Z",
"start_time": "2023-12-03T22:09:12.927801Z"
},
"id": "2H_psk3v6UaX"
},
"outputs": [],
"source": [
"def combine_x_y(x, y):\n",
"    \"\"\"\n",
"    The forward-forward network expects inputs to contain both the input data (x) and the label (y). This function\n",
"    appends the one-hot label, scaled by the maximum value in x, to each row of x. Note there are 10 possible classes.\n",
"    \"\"\"\n",
"    label_onehot = torch.nn.functional.one_hot(y, 10) * x.max()\n",
"    return torch.hstack([x, label_onehot])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:13:03.395731Z",
"start_time": "2023-12-03T22:13:03.390410Z"
},
"id": "dgwNiUuc5NBq"
},
"outputs": [],
"source": [
"class FFNN(torch.nn.Module):\n",
"    def __init__(self, dims):\n",
"        super().__init__()\n",
"        self.layers: List[Layer] = []\n",
"\n",
"        for d in range(len(dims) - 1):\n",
"            self.layers.append(Layer(dims[d], dims[d + 1]).to(device=DEVICE))\n",
"\n",
"    def predict(self, x):\n",
"        goodness_per_label = []\n",
"        for label in range(10):\n",
"            \"\"\" Generate an array of just that label, since we want to append that label to all x \"\"\"\n",
"            label_arr = [label] * x.shape[0]\n",
"            input = combine_x_y(x, torch.tensor(label_arr).to(device=DEVICE))\n",
"\n",
"            \"\"\" Run this input through the network and record the total goodness across all layers \"\"\"\n",
"            goodnesses_per_layer = self._forward_pass(input)\n",
"            goodness_per_label.append(sum(goodnesses_per_layer).unsqueeze(1))\n",
"\n",
"        \"\"\" The class predicted is the label that achieves the highest total goodness. With X staying constant,\n",
"        the only difference comes from the label appended to X. \"\"\"\n",
"        goodness_per_label = torch.cat(goodness_per_label, 1)\n",
"        return goodness_per_label.argmax(1)\n",
"\n",
"    def train(self, x_positive, x_negative, n_epochs, mode='batch', validation_x=None, validation_y=None):\n",
"        \"\"\" Use the activations from the previous layer for the positive and negative data, respectively, to train the\n",
"        subsequent layers. In 'sample' mode, evaluate the whole model's accuracy on the validation data every 10 samples.\n",
"        In 'batch' mode, n_epochs is unused; each layer runs its own num_epochs iterations. \"\"\"\n",
"        self._history = {'accuracy': []}\n",
"        if mode == 'batch':\n",
"            self._train_batch(x_positive, x_negative)\n",
"        elif mode == 'sample':\n",
"            # pass the caller's n_epochs through instead of a hard-coded value\n",
"            self._train_sample(x_positive, x_negative, validation_x, validation_y, n_epochs=n_epochs)\n",
"\n",
"    def _train_batch(self, x_positive, x_negative):\n",
"        h_positive, h_negative = x_positive, x_negative\n",
"        for i, layer in enumerate(self.layers):\n",
"            print('Training layer', i)\n",
"            h_positive, h_negative = layer.train(h_positive, h_negative)\n",
"\n",
"    def _train_sample(self, x_positive, x_negative, validation_x, validation_y, n_epochs):\n",
"        for epoch in range(n_epochs):\n",
"            print('Epoch', epoch)\n",
"            pbar = tqdm(range(len(x_positive)), desc='Training samples')\n",
"            for i in pbar:\n",
"                h_positive_sample, h_negative_sample = x_positive[i].unsqueeze(0), x_negative[i].unsqueeze(0)\n",
"                for j, layer in enumerate(self.layers):\n",
"                    h_positive_sample, h_negative_sample = layer.train(h_positive_sample, h_negative_sample, verbose=False)\n",
"\n",
"                if i % 10 == 0 and validation_x is not None and validation_y is not None:\n",
"                    accuracy = self.predict(validation_x).eq(validation_y).float().mean().item()\n",
"                    self._history['accuracy'].append(accuracy)\n",
"                    pbar.set_postfix({'Validation accuracy': accuracy})\n",
"\n",
"    def get_losses(self):\n",
"        return [l.losses for l in self.layers]\n",
"\n",
"    def _forward_pass(self, input):\n",
"        h = input\n",
"        goodnesses = []\n",
"        for layer in self.layers:\n",
"            \"\"\" Goodness is computed as the mean squared activation of this layer \"\"\"\n",
"            activation = layer(h)\n",
"            activation_magnitude = activation.pow(2).mean(1)\n",
"            goodnesses.append(activation_magnitude)\n",
"\n",
"            \"\"\" Use the activation of this layer as the input to the next layer \"\"\"\n",
"            h = activation\n",
"\n",
"        return goodnesses"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kid_R70k-8Bb"
},
"source": [
"Now let's define our dataloaders for the MNIST data, with training and test datasets that are both normalized."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:09:13.233675Z",
"start_time": "2023-12-03T22:09:13.230446Z"
},
"id": "umico8I1_lkH"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"from torchvision.transforms import Compose, ToTensor, Normalize, Lambda\n",
"from torchvision.datasets import MNIST\n",
"\n",
"\n",
"def MNIST_loaders(train_batch_size=50000, test_batch_size=10000):\n",
"    transform = Compose([\n",
"        ToTensor(),\n",
"        Normalize((0.1307,), (0.3081,)),  # mean and std of the MNIST dataset\n",
"        Lambda(lambda x: torch.flatten(x))])\n",
"\n",
"    train_loader = DataLoader(\n",
"        MNIST('./data/', train=True, download=True, transform=transform),\n",
"        batch_size=train_batch_size, shuffle=True)\n",
"\n",
"    test_loader = DataLoader(\n",
"        MNIST('./data/', train=False, download=True, transform=transform),\n",
"        batch_size=test_batch_size, shuffle=False)\n",
"\n",
"    return train_loader, test_loader\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:09:13.715779Z",
"start_time": "2023-12-03T22:09:13.688937Z"
},
"id": "5s8mhh_s-7Dv"
},
"outputs": [],
"source": [
"torch.manual_seed(42)\n",
"train_loader, test_loader = MNIST_loaders()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:04:24.417392Z",
"start_time": "2023-12-03T22:04:22.193552Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bUk_OpUAAJ13",
"outputId": "4e7efa49-d9ec-4fb3-8f15-1c091b50c429"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4242, -0.4242, -0.4242, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.4242, -0.4242, -0.4242, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.4242, -0.4242, -0.4242, ..., 0.0000, 2.8215, 0.0000],\n",
" ...,\n",
" [-0.4242, -0.4242, -0.4242, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.4242, -0.4242, -0.4242, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.4242, -0.4242, -0.4242, ..., 2.8215, 0.0000, 0.0000]])"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x, y = next(iter(train_loader))\n",
"x, y = x.to(DEVICE), y.to(DEVICE)\n",
"\n",
"x_positive = combine_x_y(x, y)\n",
"x_positive"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5aiC-oPwOzUi"
},
"outputs": [],
"source": [
"\n",
"x_test, y_test = next(iter(test_loader))\n",
"x_test, y_test = x_test.to(DEVICE), y_test.to(DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:13:09.490247Z",
"start_time": "2023-12-03T22:13:09.484813Z"
},
"id": "6N0OFm61AD2c"
},
"outputs": [],
"source": [
"\"\"\" Let's make a network with 2 layers of 100 neurons each. \"\"\"\n",
"input_dimension = x_positive[0].size(0)\n",
"network = FFNN([input_dimension, 100, 100])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:13:51.758600Z",
"start_time": "2023-12-03T22:13:09.752786Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dLdX_4G3Bdgs",
"outputId": "07522836-929b-423d-ac15-763ae2c3bbd1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training samples: 100%|██████████| 50000/50000 [27:00<00:00, 30.86it/s, Validation accuracy=0.897]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training samples: 100%|██████████| 50000/50000 [17:02:27<00:00, 1.23s/it, Validation accuracy=0.907] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training samples: 100%|██████████| 50000/50000 [42:35<00:00, 19.57it/s, Validation accuracy=0.912] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training samples: 100%|██████████| 50000/50000 [27:07<00:00, 30.72it/s, Validation accuracy=0.913]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training samples: 100%|██████████| 50000/50000 [27:39<00:00, 30.12it/s, Validation accuracy=0.914]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Training samples: 1%|▏ | 705/50000 [00:28<33:42, 24.37it/s, Validation accuracy=0.911]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"KeyboardInterrupt                         Traceback (most recent call last)",
"Cell 13, line 5: network.train(x_positive, x_negative, n_epochs=10, mode='sample', validation_x=x_test, validation_y=y_test)",
"Cell 13, line 32: self._train_sample(x_positive, x_negative, validation_x, validation_y, n_epochs=20)",
"Cell 13, line 47: h_positive_sample, h_negative_sample = layer.train(h_positive_sample, h_negative_sample, verbose=False)",
"Cell 13, line 62: loss.backward()",
"File ~/mambaforge/envs/torch-metal/lib/python3.10/site-packages/torch/_tensor.py:522, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)",
"File ~/mambaforge/envs/torch-metal/lib/python3.10/site-packages/torch/autograd/__init__.py:266, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)",
"KeyboardInterrupt: "
]
}
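],
"source": [
"\"\"\" We can generate simple negative data by pairing each image with a label taken from a random permutation of the batch's labels. Roughly one in ten of these pairs will still carry the correct label by chance. \"\"\"\n",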
"rnd = torch.randperm(y.size(0))\n",
"x_negative = combine_x_y(x, y[rnd])\n",
"\n",
"network.train(x_positive, x_negative, n_epochs=10, mode='sample', validation_x=x_test, validation_y=y_test)\n",
"\n",
"print('Accuracy:', network.predict(x).eq(y).float().mean().item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VJV8sE07OzUi",
"outputId": "da6253d1-13ce-412d-e52e-1ab280945025"
},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x3022ef4c0>]"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(network._history['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:13:54.915422Z",
"start_time": "2023-12-03T22:13:54.154311Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_UQYITJGCzLH",
"outputId": "da4566fb-e37a-4a35-ac6b-7fba06926113"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9107999801635742\n"
]
}
],
"source": [
"x_test, y_test = next(iter(test_loader))\n",
"x_test, y_test = x_test.to(DEVICE), y_test.to(DEVICE)\n",
"\n",
"print('Test accuracy:', network.predict(x_test).eq(y_test).float().mean().item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-03T22:14:00.468042Z",
"start_time": "2023-12-03T22:13:59.979175Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
"height": 573
},
"id": "KXsNdJfYC9YK",
"outputId": "c3e43afe-2e11-456d-c2f8-bdd2cc363cb1"
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def plot_loss(loss: List[float], title):\n",
"    plt.plot(range(len(loss)), loss)\n",
"    plt.title(title)\n",
"    plt.xlabel('Training epoch')\n",
"    plt.ylabel('Loss')\n",
"\n",
"\n",
"losses = network.get_losses()\n",
"plot_loss(losses[0], 'Layer 0 loss')\n",
"plt.show()\n",
"plot_loss(losses[1], 'Layer 1 loss')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NLZcP3DgOzUj"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": [],
"include_colab_link": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
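
The per-layer loss in `Layer.train` is a softplus applied to the signed distance between goodness and the threshold. The sketch below restates it as two standalone helpers, `goodness` and `ff_loss`; both names are hypothetical and do not appear in the notebook. It uses `torch.nn.functional.softplus`, which computes the same quantity as `log(1 + exp(z))` but avoids overflow for large `z`. This is one possible way to write the loss, assuming the notebook's threshold of 2.0, and not a replacement for the author's code.

```python
import torch
import torch.nn.functional as F


def goodness(activations: torch.Tensor) -> torch.Tensor:
    """Mean squared activation per sample (hypothetical helper mirroring Layer.train)."""
    return activations.pow(2).mean(dim=1)


def ff_loss(g_pos: torch.Tensor, g_neg: torch.Tensor, threshold: float = 2.0) -> torch.Tensor:
    """Average softplus penalty: positives below the threshold and negatives above it cost more."""
    margins = torch.cat([threshold - g_pos, g_neg - threshold])
    # softplus(z) equals log(1 + exp(z)) but stays finite for large z
    return F.softplus(margins).mean()


# Toy activations: two positive and two negative samples, four units each
h_pos = torch.tensor([[3.0, 3.0, 3.0, 3.0], [2.0, 2.0, 2.0, 2.0]])
h_neg = torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])

g_pos, g_neg = goodness(h_pos), goodness(h_neg)
loss = ff_loss(g_pos, g_neg)

# The notebook's formulation, for comparison on the same inputs
naive = torch.log(1 + torch.exp(torch.cat([2.0 - g_pos, g_neg - 2.0]))).mean()

# A very large negative-data goodness overflows exp() in the naive form but not in softplus
big = ff_loss(torch.tensor([0.0]), torch.tensor([1000.0]))
```

On the toy inputs the two formulations agree to floating-point tolerance. They differ only when a goodness value is far from the threshold, where `torch.exp` overflows to infinity in float32 and the softplus form remains finite.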