{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction\n",
"\n",
"I realize I'm late to the structural SVM party (like 20 years late). But recently, I've been doing a deep dive into teaching myself more about structured prediction. Charles Martin got me into this. We are revisiting an old paper on Multivariate loss optimization together and wanted to see how it fits into the Pystruct package.\n",
"\n",
"I realized I needed a stronger background on structured learning in general. So the point of this article is to review an older paper on the topic:\n",
"\n",
"**Large Margin Methods for Structured and Interdependent Output Variables**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The Paper\n",
"\n",
"It was published in JMLR (Journal of Machine Learning Research) in 2005 (almost 20 years ago!). Tsochantaridis, Joachims, Hofmann, and Altun figured out how to generalize SVMs to make structured predictions.\n",
"\n",
"I attended a workshop with Joachims in RecSys 2019 in Vancoouver. As I was still a n00b (and still am to this day!). I didn't know much about the guy. Only after the fact, did I realize on how much of a big player he has been in ML (SVMs, search, structured prediction, etc.)\n",
"\n",
"\n",
"### What the hell is structured prediction?\n",
"What is structured prediction you might ask? Great question. It's predicting things like trees, graphs, etc. That is, structured objects. \n",
"\n",
"Can't we just assign each structured object a label and call it a day? This doesn't scale or even work in some cases. The core idea is to produce a function F such that:\n",
"\n",
"$f(x | w) = \\text{argmax}_y F(x, y | w)$\n",
"\n",
"This means we need to iterate over all possible outputs and take the argmax. What is $ F(x, y | w)$? In the paper, it's restricted to be:\n",
"\n",
"$F = w \\cdot \\Psi(x, y)$\n",
"\n",
"$\\Psi(x, y)$ is outputs a vector that has the same size as $w$. Where does this come from? In the paper they write\n",
"\n",
"\"The specific form of $\\Psi$ of depends on the nature of the of the problem\"\n",
"\n",
"They do provide a concrete example of one. Take the problem of decomposing a sentence into a parse tree. Each node in the parse tree maps to a grammar rule and a score. I'm not sure exactly how this score gets computed (haven't worked a ton on this problem), but my guess is that it's some type of quality score. How well it thinks the grammar rule applies. \n",
"\n",
"The idea is to sum the weights for each grammar rule. "
]
},
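{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the argmax concrete, here is a minimal sketch (my own toy example, not from the paper): a made-up feature map that just concatenates $x$ and $y$, a made-up weight vector, and a brute-force loop over a small multilabel output space that computes $\\text{argmax}_y \\, w \\cdot \\Psi(x, y)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import numpy as np\n",
"\n",
"def psi_toy(x, y):\n",
"    # Toy feature map: concatenate x and y.\n",
"    # A real Psi would encode how x and y interact.\n",
"    return np.concatenate((x, y))\n",
"\n",
"def predict(x, w, y_space):\n",
"    # f(x | w) = argmax_y w . Psi(x, y), by brute-force enumeration\n",
"    scores = [w.dot(psi_toy(x, y)) for y in y_space]\n",
"    return y_space[int(np.argmax(scores))]\n",
"\n",
"y_space_toy = [np.array(y) for y in itertools.product([0, 1], repeat=3)]\n",
"x_toy = np.array([0.5, -1.0])\n",
"w_toy = np.array([0.0, 0.0, 1.0, -1.0, 0.5])  # made-up weights, same size as Psi's output\n",
"print(predict(x_toy, w_toy, y_space_toy))  # -> [1 0 1]"
]
},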
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loss\n",
"\n",
"They consider a loss function $\\Delta(y, f(y))$ that outputs a real non-negative number. The loss is 0 if label and predictions agree. Otherwise, greater than 0.\n",
"\n",
"Here, $f$ is the model we learn from data"
]
},
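{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make $\\Delta$ concrete, here is a minimal sketch (my choice of loss, not mandated by the paper): for multilabel outputs, a Hamming-style loss counts the label positions where the two labelings disagree."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def hamming_delta(y_true, y_pred):\n",
"    # Delta(y, y') = number of label positions where the labelings disagree;\n",
"    # it is 0 iff they agree exactly, positive otherwise\n",
"    return int(np.sum(np.array(y_true) != np.array(y_pred)))\n",
"\n",
"print(hamming_delta([1, 0, 0], [1, 0, 0]))  # 0\n",
"print(hamming_delta([1, 0, 0], [0, 0, 1]))  # 2"
]
},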
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Risk\n",
"\n",
"In the supervised settings, we are given (x,y) examples. Assuming these were drawn at random, you can think of these as being drawn from some P(x,y) distribution that is unknown to us. And we want to minimize the loss associated with it.\n",
"\n",
"We want to make a prediction over the entire space, so in theory, we want to minimize:\n",
"\n",
"$R^{\\Delta}_{P}(f) = \\int \\Delta(y, f(y)) dP$\n",
"\n",
"In practice, this is not known of course, so we settle for the empirical form:\n",
"\n",
"$R^{\\Delta}_{S}(f) = \\frac{1}{n} \\sum_i \\Delta(y_i, f(y_i))$\n",
"\n",
"The $S$ means it's drawn from the sample space we are working with"
]
},
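{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the empirical risk, reusing the toy hamming_delta above; f_const is a placeholder model that always predicts all zeros."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def empirical_risk(f, X, Y, delta):\n",
"    # R_S(f) = (1/n) * sum_i Delta(y_i, f(x_i))\n",
"    return float(np.mean([delta(y, f(x)) for x, y in zip(X, Y)]))\n",
"\n",
"# Two toy examples and a placeholder model\n",
"X_demo = [np.array([0.5, -1.0]), np.array([1.0, 2.0])]\n",
"Y_demo = [np.array([1, 0, 0]), np.array([0, 1, 0])]\n",
"f_const = lambda x: np.array([0, 0, 0])\n",
"print(empirical_risk(f_const, X_demo, Y_demo, hamming_delta))  # 1.0"
]
},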
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Margin\n",
"\n",
"To learn a structured model, they choose to learn with SVMs. It seems the core reason being is that they are able to devise an algorithm that avoids having to predict over the entire $y$ space.\n",
"\n",
"They work their idea with increasingly harder problems in the order:\n",
"\n",
"* Hard Margin (case when all examples are linearly separable)\n",
"* Soft Margin\n",
"* Lost-Sensitive SVM (generalizes both)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hard Margin\n",
"\n",
"In this case, the solution is linearly separable so zero loss can be reached. This means for a given example, the loss of the correct prediction will be smaller than all other predictions.\n",
"\n",
"When the loss is minimized, $w \\cdot \\Psi(x, y)$ is a maximum under the correct prediction. Why? It measures the most probable pairing. Suppose the mulitlabel problem has three labels (so possible outcomes can be [0, 0, 1], [1, 0, 1]). Let's look at all 8:\n",
"\n",
"Let's suppose the correct output is [1, 0, 0]. Then all the following will be true when we have reached the optimal solution (i.e., linear separation):\n",
"\n",
"(1) $w \\cdot \\Psi(x, y = \\text{100}) >= w \\cdot \\Psi(x, y = \\text{000})$\n",
"\n",
"(2) $w \\cdot \\Psi(x, y = \\text{100}) >= w \\cdot \\Psi(x, y = \\text{010})$\n",
"\n",
"(3) $w \\cdot \\Psi(x, y = \\text{100}) >= w \\cdot \\Psi(x, y = \\text{001})$\n",
"\n",
"(4) $w \\cdot \\Psi(x, y = \\text{100}) >= w \\cdot \\Psi(x, y = \\text{110})$\n",
"\n",
"(5) $w \\cdot \\Psi(x, y = \\text{100}) >= w \\cdot \\Psi(x, y = \\text{011})$\n",
"\n",
"(6) $w \\cdot \\Psi(x, y = \\text{100}) >= w \\cdot \\Psi(x, y = \\text{101})$\n",
"\n",
"(7) $w \\cdot \\Psi(x, y = \\text{100}) >= w \\cdot \\Psi(x, y = \\text{111})$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So there are $|y| - 1$ constraints (we drop the constraint for the correct prediction--a value being equal to itself is not a a constraint). But we have $N$ examples, that means $N(|y| - 1)$ in total. We write this compactly as:\n",
"\n",
"For each example \"i\", let the index \"k\" denote the correct answer. And let \"i\" denote one of the possible values of y. But where \"i\" can never be \"k\" (i.e., itself) then:\n",
"\n",
"$w \\cdot \\Psi(x, y = y_{ik}) >= w \\cdot \\Psi(x, y = y_{ij})$\n",
"\n",
"or\n",
"\n",
"$w \\cdot \\big( \\Psi(x, y = y_{ik}) - \\Psi(x, y = y_{ij}) \\big) >= 0$\n",
"\n",
"or\n",
"\n",
"$ w \\cdot \\delta \\Psi_i(y) >= 0$"
]
},
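{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of these constraints (my own illustration, reusing psi_toy, x_toy, w_toy, and y_space_toy from the earlier toy cell): build $\\delta \\Psi_i(y) = \\Psi(x_i, y_{ik}) - \\Psi(x_i, y)$ for every incorrect $y$ and check that $w \\cdot \\delta \\Psi_i(y) \\geq 0$ holds."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def delta_psis(x, y_true, y_space, psi):\n",
"    # One row per incorrect y: deltaPsi(y) = Psi(x, y_true) - Psi(x, y)\n",
"    return np.array([psi(x, y_true) - psi(x, y)\n",
"                     for y in y_space if not np.array_equal(y, y_true)])\n",
"\n",
"d = delta_psis(x_toy, np.array([1, 0, 1]), y_space_toy, psi_toy)\n",
"print(d.shape)                 # (7, 5): |y| - 1 constraints for one example\n",
"print(np.all(d @ w_toy >= 0))  # True: w_toy linearly separates this example"
]
},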
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To compare the efficay of this approach. I want to try a few experiments. All will use generated data from the following dataset:"
]
},
{
"cell_type": "code",
"execution_count": 215,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(670, 8) (670, 3) (330, 3)\n"
]
}
],
"source": [
"from sklearn.datasets import make_multilabel_classification\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"n_features = 8\n",
"n_classes = 3\n",
"\n",
"X, y = make_multilabel_classification(n_samples=1000, n_features=n_features, n_classes=n_classes,\n",
" n_labels=2,\n",
" length=50, allow_unlabeled=False, sparse=False,\n",
" return_indicator='dense',\n",
" return_distributions=False, random_state=23)\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33,\n",
" random_state=42)\n",
"\n",
"print(X_train.shape, y_train.shape, y_test.shape)\n",
"\n",
"# Friendlier names\n",
"n_train_examples = X_train.shape[0]\n",
"n_test_examples = X_test.shape[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need a function that takes as input x and y and outputs a vector. This is called the \"feature map\" $\\Psi(x, y)$ in the paper. How do we make one? They do describe a way in Section 4 (Specific Problems and Special Cases), but I'm going to use a very simple autoencoder to illustrate the idea instead. Let's build it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature Map\n",
"Let's now build $\\Psi(x, y)$"
]
},
{
"cell_type": "code",
"execution_count": 216,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import itertools\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 217,
"metadata": {},
"outputs": [],
"source": [
"class Coder(nn.Module):\n",
" def __init__(self, n_input, n_output):\n",
" \n",
" super(Coder, self).__init__()\n",
" self.dense = nn.Linear(n_input, n_output)\n",
" \n",
" def forward(self, x):\n",
" x = self.dense(x)\n",
" x = F.relu(x)\n",
" return x\n",
"\n",
"class AutoEncoder(nn.Module):\n",
" def __init__(self, n_input, n_hidden):\n",
" super(AutoEncoder, self).__init__()\n",
" \n",
" self.encoder = Coder(n_input, n_hidden)\n",
" self.decoder = Coder(n_hidden, n_input)\n",
" \n",
" def forward(self, x):\n",
" x = self.encoder(x)\n",
" x = self.decoder(x)\n",
" return x\n",
" \n",
" def get_feature_map(self, x):\n",
" return self.encoder(x)"
]
},
{
"cell_type": "code",
"execution_count": 218,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"class Dataset(torch.utils.data.Dataset):\n",
" def __init__(self, X, y):\n",
" self.X = X\n",
" self.y = y\n",
"\n",
" def __len__(self):\n",
" return self.X.shape[0]\n",
"\n",
" def __getitem__(self, index):\n",
" X = self.X[index]\n",
" y = self.y[index]\n",
" return X, y\n",
" \n",
"fm_train = np.concatenate((X_train, y_train), axis=1).astype('float32')\n",
"n_input = fm_train.shape[1]\n",
"reduction_fraction = 0.3\n",
"n_hidden = int(np.round(n_input * reduction_fraction))\n",
"\n",
"auto_encoder = AutoEncoder(n_input=n_input, n_hidden=n_hidden)\n",
"\n",
"loss = nn.MSELoss()\n",
"optimizer = optim.SGD(auto_encoder.parameters(), lr=0.001)\n",
"n_epochs = 200\n",
"\n",
"train_dataset = Dataset(fm_train, fm_train)\n",
"train_dataloader = DataLoader(train_dataset,\n",
" batch_size=8,\n",
" shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 219,
"metadata": {},
"outputs": [],
"source": [
"m = 0\n",
"n_training_examples = []\n",
"losses = []\n",
"\n",
"for epoch in range(1, n_epochs + 1):\n",
" for i, data in enumerate(train_dataloader, 0):\n",
" X_train_, y_train_ = data\n",
" \n",
" optimizer.zero_grad()\n",
" \n",
" y_hat = auto_encoder(X_train_)\n",
" out = loss(y_hat, y_train_)\n",
" out.backward()\n",
" optimizer.step()\n",
" \n",
" # Update number of examples trained on\n",
" m += X_train_.shape[0]\n",
" n_training_examples.append(m)\n",
" losses.append(out / X_train_.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": 220,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x12cb3f8e0>]"
]
},
"execution_count": 220,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD4CAYAAAANbUbJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAArT0lEQVR4nO3deXwU5f0H8M+TBMJ9R26IoHIqhxFBAREREKhWW61aby1Wbev1K0IVbwWxHrVaxaNeRRFRsQUFueSQQ8MVwh3CfSWEI0BIQpLn98fObmZ3Z3Znd2d2Znc+79eLF7uzszPfzM5855nneeYZIaUEEREltxS7AyAiIusx2RMRuQCTPRGRCzDZExG5AJM9EZELpFmx0GbNmsnMzEwrFk1ElJRWrVp1WEqZYdXyLUn2mZmZyM7OtmLRRERJSQixy8rlsxqHiMgFmOyJiFyAyZ6IyAWY7ImIXIDJnojIBZjsiYhcgMmeiMgFHJfsi06W4fv1B+wOg4goqTgu2d/1cTbum7IaR0+V2x0KEVHScFSy/+inHVi35xgAoKKKD1UhIjKLo5L90//baHcIRERJyVHJnoiIrMFkT0TkAkz2REQuwGRPROQCjk32EuyNQ0RkFscmeyIiMg+TPRGRCzDZExG5AJM9EZELMNkTEbkAkz0RkQsw2RMRuYBzkz272RMRmca5yZ6IiEzDZE9E5AJM9kRELsBkT0TkAkz2REQuwGRPROQCjk327HlJRGQexyZ7IiIyD5M9EZELMNkTEbkAkz0RkQsw2RMRuQCTPRGRCzDZExG5AJM9EZELMNkTEbmA4WQvhEgVQqwRQsy0MiAiIjJfJCX7BwFssioQIiKyjqFkL4RoA2AkgPetDYeIiKxgtGT/OoAxAKr0ZhBCjBZCZAshsgsLC82IjYiITBI22QshRgEokFKuCjWflPJdKWWWlDIrIyPDtACJiCh2Rkr2lwK4WgixE8BUAIOFEP+xNCoAkmMcExGZJmyyl1KOk1K2kVJmArgRwAIp5S2WR0ZERKZhP3siIhdIi2RmKeWPAH60JBIiIrIMS/ZERC7AZE9E5AJM9kRELsBkT0TkAo5N9hLsaE9EZBbHJnsiIjIPkz0RkQsw2RMRuQCTPRGRCzDZExG5gGOT/bZDJ+0OgYgoaTg22d/275/tDoGIKGk4NtkDwJ4jJcjdd9zuMEzzw4aD+GT5TrvDICIXimjUy3gbMGkhAGDnxJE2R2KO0Z96HvZ1W79MewMhItdxdMmeiIjMwWRPROQCTPZERC7AZE9E5AJM9kRELsBkT0TkAkz2REQuwGRvktIzlSirqLQ7DCIiTUz2Juk8fjYunbjQ7jCIiDQx2Zvo8Mkyu0MgItLEZB+gpLwC2ws54iYRJRcm+wD3froKV7yyCJVVfOA5ESUPJvsAP+UdtjsEIiLTMdkTEbkAkz0RkQsw2QdgTT0RJSMmex3C7gCIiEzEZE9E5AJM9kRELsBkT0TkAo5+4LjX4q2FqF8rDUdOleOKLs0tXZdkCy0RJaGESPb3T1mNk2UVAICdE0faHA1RsMyxs3B7v/Z45prudodCpClsNY4QopYQ4mchxDohxAYhxDPxCEzNm+iJnOzj5bvsDoFIl5E6+zIAg6WUPQD0BDBcCNHX0qhC4MiSRESRC5vspYd3GMgayj/barb7v7TArlUTESUsQ71xhBCpQoi1AAoAzJVSrrQ0qhBKz1TFZT3bC0/i9XlbIdliS0RJwFCyl1JWSil7AmgDoI8QIqgVSggxWgiRLYTILiwsNDnM+LvpvZV4fd42HCs5Y3copus3YT4ufnGe3WE4WubYWXj4i7V2h0Fkmoj62UspjwFYCGC4xmfvSimzpJRZGRkZJoVnn/Ikfp7sgeOlOFTMto9wvlmzz+4QiExjpDdOhhCikfK6NoArAWy2OC4iIjKRkX72LQF8LIRIhefkME1KOdPasIiIyExhk72UMgdArzjEQkREFuHYOERELsBkHwY7XrrD3qMlyBw7i88gpqTFZE8EIHvnUQDAtOw9NkdCZA0m+zD4xCoiSgZM9kRkmz1HSnC6PHnvaXESJnsiFY6OEV8DJi3EnR/9bHcYrsBkr4PHvLsI1tfZZkX+EbtDcAUm+zCY9IkoGSRVsj9w/DTOVJozKiYLekSUTJIm2Z8sq0C/CQswfkau3aFQAgt1JffRTzuQOXYWh72mhJQ0yb5EeXTh/M0FpiwvGQ/npdsO41hJud1hJKxnZ24EAFQl486RoE6VVaDMphFqj5ecwdFTiXM8JU2yt0qyVOeUlFfglg9W4s6Pfgk775nKKgz++4+Yt/FQHCIjil63p+bg6n/+ZMu6ezz7A3o9N9eWdUeDyd4lKpTiaN6hk2HmBIpOliP/8Ck8PmO91WE5TrKc3N1ky6ETdoeQEJjsKWJ23AizZvdRvDp3q+XrYQ0NJaukSfZWHaQ8+IMNmLQQd38cvjrITNf+axnemL8truvUwwba5HTkVDkmfLcJFSb16HOahEz2VVUSv39/BRZvDX7WrVmX4cl2OR9Nfgr1nWXbi6IPJoRVu45i+qq9liw7VoJ3Xlmm9Eyl7SfRp/+7AZMX52PeJnM6eThNQib7E6UV+CmvCH/6bLXdoSQlO3Pab95ehv/7cp19ASQgKSX2Hzvte3/Sxh4q0eo8fjb+YdKV26crdmHPkZKIv1de4SnR233SsUpCJvvHvsqxfB3J9nOzUGqMHQd6wYlSHDxeGvX3P162E5dMXIAN+48DALo/NQednphtVnhxM8OEB7wXl57B+Bm5uPn9FSZElFwSMtnP3nAQAFBcWoEjSj9Xo8dozt5jeHDqGlQZ7CzNHGmtqiqJguLoE51Z7Kyi6fPCfPSdMD/q76/c4RlbZldR5KXZZCOV6vZjJWfw1sK8iEr4MumKeP4SMtmr9Q7o5xrumL3301X4du1+HDQxwby/JN+0ZUXqwPHTeOCz1Rg/IxcnSs/ozhdVnX0McRn11sI89HlxflSX3XZJ7pSQHE6UVuDlOVtw+785oqZXwid7qxk5sJ+ftcnyOPQ8+e0GzMo5gE9X7MJrcw3UeRoowMazjLt4m6eR/UAM1Rh6Xpq9Gcu2u+cxg4En9ESrt7fC6TPGt4FI8ut41yb7cEnc6T970cmyoPrlyip7u4xJKbG9MPxNW/Hy9o/bcfN7K01fbrz2jZX5Rfh2bfh6bO/VbGA1RJK2M4YWw4/DahyyxP5jp7F+7/Govru7qAQXPj8P7y3Jj/yAjmD+whNlePvH7Ybnf29JPq54ZVHUfxf5+927K/Dg1LVh5/OWSO1M7ku3Hca0X5Lj+b16VcGnyiqwO4HbRZIi2a/efRQ5e48FTX9/ST7u1hkLpvBEGUrKKyyOTN8lExfgV28ujeq7K/I9fdx/3BJ8n0E0th06gbk64+C8NHuz4eWs3nUMALD3qOeA2HfsNL74ZXfI79iRoHYcPmWoxKxn9oaDvm56oVRWSVTGY9Q0neQUbZvzrqJTeHDqGkN/o9ctH6zEmDj0krPCws0F+PfSHb73evvkze+twMCXF8YpK
vMlRbK/7l/LMPrTVQD8692en7UpaBRM76e/fusnXP/Oct1lFpdadyII1ZAaTnHpmYCDqnrPNNKj5ERZ8N915WuL8YdPsqOOSc9N767AY1+txyllnVJKfLf+gGYSiWdnmGGvLTZUYtbzp8/W4JW5W8LON3DSQnR7Kr5dIPep+ttHa8z0HHy7dj9W7Tqq+fmE7zYhe2fyPF3qzo9+8Y1oqqWySmJ3UQnWJfgVa1Ike7WDxaWGe8ds2F+MzLGzsKvolMVR+Xvmf9U71nfrD0SU/EvK7G10i6QfetHJMs93lPc/bi3E/VNW4/V51o9xE0p5iNvhQ/116vPR/mPhG5T3HTuN0jP66zpRegZjpsd+A5n3hqpVu47i0okLYl5eOJMX5+O3IQpKZiurqDTc2BxYaDhwvDTiE2DgMv4xf1tQif7nHYl3sku6ZA9E3jtGrwoD8D/Aj5wqR+6+2M/uxaerk/v9U1bj0WnRHfB7jvrXH+4+UoLiCE4cVVUSB47HXhIMpJcwj5d4YjOj9Kklv/Bk1FdNkV5YmHHz1eRF+ZiWHfvQEGt2HwMAfLRsp9/0aHuXrFOqROdvMn+I66oqiRe/26S5D+wsKtGs9896bh66PjnH0PKXbgvufRXrCdBbbarmLcgkkqRM9maSAKZl78HvJi/HNW8txah/hq5n37i/OOKEk7vvODYfLDYYT3WS2XPktF/94oLNBfj1W8bH9n570Xb0m6BxIITIEaFyXLiqGO/n6mpsM2u0B7+yCDdMDr5zUkqJBz5bjVH/XOL32Mp8nZ5DRk6YZsRtR++P5duLkDl2FgpP6Ccr79XI0rzw3VYjbfdau/cY3l2cj4emrtH8XKve/0RZheG2j/unRD+ESrL3XkraZH/L++Z1uRszPQcrdxzBniPhS6Qj3liCu8I8IOSHgCuJ/cdLMfz1JVHFFrh/5hcar5LSKgVpmfbLHiyPYOCz+6esDnODV/BR9d36A4aXH8qmA8EnzU+W78KsnAPI3VeMzQeqxz4f/Mqi4IQvgUMW9Pm3g9bJ9wOlIXLNbu36+EjMzj2Irk/OMdT7alfRKWSOneW7iq4Ik7zv/ugX3BbBDVFWXKEmm6RN9kZKJVb5ZWfsB5JR8RjLZcxXObjpPU+JOXBtM9bs8xuEy0tdp71gcwEyx87y3TilFfGHP+00KdpgT/13g+5nBUoJV++qRPexczqb/d5Ps2Ma6z+/8KRvjBu1eRsPIa8gtnsYDp8s89V9n6mUGPtVDgpORH9i894Qt1ajJ1ygtXs88/x37X5Dy56/uUBzVFs9mleoEUr28aOSNtlb6fkQLfdWq6iMMrlbcE4or6jCQ1+sxQ2TQzfWfZntqYddr7R32DmqoF7ViTekWesP4MrXFvumv7Uwz/daXRot1bkzc86GQ/hu/QH8Z8WuqOIb/MoijHxjqV91EwDc80k2hry6KKplemU9Pw9LlKu5uRsPYuove/B0iBNhuN5d3jYYdRuU2ts/bscgB3dVrKySeOrbXN977z6Quy90leony3fivhiqi+zCZI/IB8F6X9Un12w5e49hVo5+lcaE760fmuHoKWNVMN7EWVAcXP9bdKoMpwJKuN6/S53ro0n87y/JD2o8NNpOcvWb5jyvVN2lNzCWyYu344kZ1UnE6KB7ahO+M35/g55IC6rHSsqx43B1NaAI+CzwoR6z1nt/T+2/76XZm7EzipuQotle0Vi/7zg+Xl59Uj6hdLd+U3WCzxw7K6jnzZPf6p8gncx1yV4rsTvp6u3qN3/CAyHG6V+y1b96amGYG6v2HClBQXEpDmlcruud44a9vlj7gwDzNuo/5OGl2fr90KtiLNk/P2sT7v44G9+r6vnfWWT8Tl81My4yAqvtjpb4n3j+9WMeIjUzx1h1h5mGvLoYl//9R997ITyN1aVnKtHz2bkYM137pqlw9e+RUidbrSqt/MKTplwdhlqGlBKfrQx9Q2CicV2y12J2XZ1T6v72Hi3BgEkL0efF+Rj6WvgE3uvZH/DjltBP6VEfHqFOSuosGniCnbPhELaa8JDo+6asxk6lJKoudGaOnRXxsmL5zQKrhgJ7uqxUlQyllHh5zmZs2H885IlGr5ooEmUVVWHr5NUxHNboTnjB0z/gun8tAwDM0LnrWOuhI4HDVnv3ASNJenbuQd/rkW/49377x7xtGPzKInxjwtj3oazIP4K/fbPe0nXEG5O9BQL35/KKKkxZGV0dbtCyI5g3VPc6IPjRgkdLzmBSiBJ5RFTZs0Tjrl1vf2r1iSCvwP8EsPdoSdgud6VxGNkx1A09R0+V46cIOgOUV1bhrYXbca2SQPWYcQf3XR/9gj4vhB4n//vcg747nAN5f5qNSg8nvV9CK3/f+oF/T5q3Fvhf3UR7bn1NuSEvXL16rE6fsW8oFasw2VtEfaPWmwvz8Pg3uSHmNs7uR6bprX7JtkJ8ryqRqQ/mbI3b7r2LUf89Q16tvvrYc6QE/V9aiNfmGrvbNto+696E9qfPtPt9/7i1EJ2emI11Sm+SQHd8+HPYxOMdI2dX0SmsyD/iDThqVVUSX6/eiz1HSkI22q40eJfn+G9z8d7i4LvOA2/KimTXU/fQemLGemwx4UrOL5YoNuCa3UexLO8wXpi1EVIm+xiXwdLsDiBe8gpOokOzupqfeXfp4yXRj1kT6A+fZOPGi9ringEd9LvvBTD7qfZWPH1Jr/rlMZ36XD0fLN2BD5buQO92jTQ/91Y/vLkwD1ed30J3Oev3HsfMdQeCkqfRQbzCJTBvd0etcWJKyisMjZeybHsRrnt7md8JI9SQDeFMy96DsV+vR3paCsoiGKxMTb1vfLNmH76WxqpFFoXpDpm77zi6t27oN+0/K6rrvr2bu6S8Eofi/IQy9dXUjX3aBd2TkuzpP2zJXgjRVgixUAixUQixQQjxYDwCM9Pmg8UY8uoiDJgUuhvYaxGM2fKc0v0yVN3q1F/24IEIumj9d11wo9ynK3Yhc+wsX910JKxoOpg0x1g1T6znGXUSVtfbBl7Z/HV6jl+Dntckg6N1at2EZVQkvTL0rgyikaN0YY020QP++4beCU/rNwz35KejJaELNt51bT54Ahe/GP2jGGO9wJUSQQ+29111mbB8JzJSjVMB4FEpZVcAfQE8IIToam1Y5vJeUu47dhrHNHZGbymnIoKHf3ywdAe+WrUXX64KP7aJ0cSnLo3e+eHP+GzlboxXuvANUvWSsNLGKJPfiYA65lhPNJEeazsCToZGH6Ly7MyNUY9RPt3Abx+K0f3i1g/87wZP5F4iRv5mI4+o1LoD2K4hy+N9hRKtsNU4UsoDAA4or08IITYBaA3AvjuLIvSPedW9BbTa+3xP+tH4LFQd+aMBJQMtEtJwKeHzn6sP4oVbCsN2q7SD3vYIHDr5uM6NNkbpNRrqCRyCIhLhSqOAvT2slhgc1iKUaK4s7PqTtYbhDrTnaPBd20YHS/MIM1zDx8aH/N5VVILCE2V4YkYuPv9DX9SumRpBHPETUZ29ECITQC8AQQPPCCFGAxgNAO3atYsqmF/1aIX/aVRlxEpdr6r1
TMpwvVZisfXQSRw0ONaKkfpfoyXe37y9THc88lgYTTzbDY7Ro/f33PGh9vhCdl1eW7HeSB4O4hXJEAJeo/65JKgR+WsDXRejafPxbSedr5q1HY8YbAeLh+mr9iB3XzE2HijGmj1HcUnHZnaHpMlwbxwhRD0AXwF4SEoZdK0vpXxXSpklpczKyMiIKpjaNezpHBSqFGpGI6eZD0IxerBYkeif/Fa7R1EsjY2Rel2jT7eWSH6373LNGYQtUu8tifxObPWduUZZ3U1RS2C1ntfBGKo8XpjlX5mwcHPoe0LiZVr23qirP+PJUHYVQtSAJ9FPkVJ+bVUwv7+4vVWLNiQR2mTsrE74RHVruV3eMJjsIzF5kbGH3VjNyEBnuw3UZ5slmn2tpLxSd/jiWAWeHO8MM7psKJZdITo4iRjpjSMAfABgk5TyVSuDqZFqT8k+VBIz2qsjHgpOlBqqR734xXmWx+J0Zp8T43GSjXWgM7N5H4oSiTcXbsMMgyNbUnwZya6XArgVwGAhxFrl3wiL43KMqRpPzrHL/E0FQYOLaTmkMTCZU0WTUMi57KgyMiKwY0HIoT6SlJHeOEvhrLHCLOP0vrXzNzmjjtKNth6KbSx5charfk8npxAOl6Aw+kBjSgzzTW68U3eLJUpErhkuIZwbJq8w9S5HK8yz4AHQRGQeJ9cOOKpkb2dPE6cneiKKXqHG8M1u46hkT0RkhXBDPbsBkz0RkQs4Ktk75QlPRETJxlHJnogokTl5THxHJfuOGfXsDoGIKCk5KtnbNVwCEZEZTHuGswWYXYmITLJ+X/hhyu3CZE9E5AJM9kRELsBkT0TkAkz2REQuwGRPROQCjkv2z/+6u90hEBElHccl+99f3M7uEIiIko7jkr3gADlERKZzXLInIiLzOTrZz3vkMrtDICJKCo5O9u2a1LE7BCKipODoZE9EROZgsicicgFHJ3t2zCEiMoejkz0REZnDkcm+c4v6AAB1wf7egR3sCYaIKAmk2R2Alml/7IeC4lK/aeNGdMHkxfk2RURElNgcmewb1KqBBrVqoLLKuQ/vJSJKJI6sxiEiInM5OtmzMw4RkTkcneyJiMgcjk727GdPRGQORyd7IiIyh6OTPce2JyIyh6OTPRERmYPJnojIBcImeyHEv4UQBUKI3HgERERE5jNSsv8IwHCL49B1xyWZ+PKP/exaPRFRUgg7XIKUcrEQIjMOsWh6+upudq2aiChpmFZnL4QYLYTIFkJkFxYWmrVYIiIygWnJXkr5rpQyS0qZlZGRYdZiiYjIBOyNQ0TkAkz2REQuYKTr5ecAlgPoJITYK4S42/qwiIjITEZ649wUj0CMGHheBhZvZeMvEVGkEqoa58M7LsKW523r8k9ElLASKtmnpgikp6X6vSciovASKtkTEVF0mOyJiFwgoZO9lNLuEIiIEkJCJvvnf90dfxvR2e4wKAE1q1fT7hCIbBG266UT3dK3PQBASmDC95ttjoYSCS8GrdW8QToOFZfFZV0dM+pie+Ep3/v0tBSUVVTFZd2JKCFL9l73XtYR+S+OwOyHBuC+QR0BhC+53aqcKLRk1E83NT5yjidGdgEA1ElPxSNXnqc5z9TRfeMZUtKZcN35WPh/g+K2vi/u9R/6vFGdGnFbdyJK6GQPACkpAp1bNMBjwztj6WOXY/4jg5Dz9FB8ds/FmvNfek7TiNexbOzgWMMMa+fEkRHNP+6q6KuxnhzVNervBkqLc/fXRnVq4KruLSL6zo0XtcXIC1oCADKb1sWfB5+jOV/fDpHvGwBQM83/MOqnLKdfh6ZY+tjl6NyiflTLNeKb+y/xe9+6Ue2ol7V6/JV+7+ul+1/4q/fRnx+/Iuj7N/Vphzo1Q1cWXHZe6EESb7yobbgwfZrV8y+cCcS2L75zy4V+71s2rKU7b1b7xvj2gUtjWl+8JXyyV2vTuA4a1qmBBrVqhOiDXz29U3P/gzDwx/ZqFcMBpOfqHq18r2NNmBe2b6z72cw/9/e9vrVve6x6YghG9WgZ0/rUljx2OSbfeiFevaFHyPn6ZDYJmta0bk2MHtgBANCigf6BFej2SzINz9usXk1M/M0FaNmwNt699UK8eXNvCCFQM1V71//wjovw0Z0XGV4+ACx49DLf650TR2LM8E4AgEGdMtCmcR0Iof37vnjt+Xh8RJeI1hWoVzv/3z6SB/3MUCWrAec2Q5O6/lfFobZD4Pa7+eJ2mvO1aFAL658eioeHeK6merRpGDKmNo1r48aL2uKvwzqFnE/LNT1bhZ9JxyUdm2J49xZ+J7SqEHV+mc3qokfbRrin/9l+0yPZj+MtqZJ9KNXHW/UPOOfhgb7X40d1DZk0A3XMqKs53ZtcWzSohf7nNNP9fuM6NfD5H4xVG9SvFVxaalC7+pLVyKmia8sGeO7X3dG0XjrOql8Ly8d5rlb0kl6gkRe0xHnN6wVNb9mwNoZ1a4HrercxtBy1Oy7JxHCllK6VD/9yxbma30vRSZ6BGtWpgewnqkurQ7u1QENlu2WPH6L5ncs7n4VBnc4Kmn7vwA745K4+AIBre7X2+yzwAO/VrjHmPDQQfxjQQYnXULiaruh8FiZcd77ftIz66Zj/6GVY+tjlAIAdE0bg3oEdMP/Ry4IKJg1V+8lz11Q/COiD27PQs20jXBfwt3jdcUkm6tfSrxYJLEWH+hvr16oBCf3EqT5ZXnpOM0z8zQV44PLgqy/1caB1tVQ33VgTZHpa8D7/0Z19DH3Xy3seeGJUV9yhKnxEkkPiLWmTfWBpKivgR6hVw/9PD9eNc/of+2HcVZ0x+dYLUS89DQNVl6PXX1id6LylIyHg28Ff/u0FmP3QAL/lXXV+S2Rl6u8Yw7t5kuCnd/fBpN9cEPT5oE6e9fft0EQzUXp1blEfV/dohdd+19NvujfOUAehV6uGtfDM1d0wTimFPntNN/TJbII/BRyQ6pJi07o1sWPCCN/7sxoEt4c8cPk5IU9UenXrgerrHOShlt0gRCLTMm5EFww8LwPzHhmIV66vvorZ8vxwpGmcMDu1qI8UJQNq/T6DOmXgVz1aYli34CqpujWr7xIf1aMlbupTXWqumZaCd265EB0z6qFN4zrK8gXGjeiCjhnBJ+N1Tw31vVav64ouzQEA1+gk+ydHdUVllfHWbKMnYPXGOEtpI+ugijvwSiUSjQ3W2a958sqgaYFVcQDQqUUD3+ure7TC9w8O0CzwqJ+m9/frQ1/h2ikhe+NEql2TOvj3HRdhV1EJ9h87DQDof44nWf7jxp54cOpadFJKCj3aNsK6PccAeJL0Gwu2AQCyMpsgS6mKGPZMC0yaXd0LaFi3Fvhy1V4A1Un0kSvPw4y1+wB4Sr+dVTuO91LxTKV2z4GdE0dCSomcvcfRo20jzXlaNqztW84N7yzXnKdDs7pIS03BGzf1CvrMWzLTOseNH9UVeQUncONF7ZCaItC9tefS+/JOZ2Hmn/ujW6sGuK1fZtD32jSuLlX2bt8YQghsfm448gpO4uxmdTEz5wAAYO7DA5G966gvGXri8de1ZQO/9zP/3B+j/rkUN/c
Jri5Y/8wwZI6d5Xvv7RFiRcebc87yL1Gqh++IhLckWb9WDeycONIv/keHdsKzMzfiul6tcW0vT0Fi47PDUDM1RfPEEuj/hp6Hv/+w1Vc92LRuTRSdKtesTvKW/Ns1qeM3PSVFaFZj1KmZipLyyqAfbMC52nXx3sLEVd1b4vV52zDi/BbYuL8Y8zYdwtyHL0Nx6Zmwf48Wb1WZ186JIzFjjed4G9KlOZZvP4xT5ZVB3zuveb2gdoVZf+nv9372QwOwYnsRft55BADw1s29fW0+9wzogDHTc3Tjql0zuv0hHpI22XsTz1+HdcLogR1QIzUF3Vs39DU6DeumlGx6tkavto3RrqlnZ592b190emI26qen4fqstrg+K3SDUZeWDXBFF89l/+iBHVCrRqovCX+j7Hxei/46CAeOlxqKXwihm+h/UFU/AcBvLmzt2zHV/qTTEOlZfvXruy49G0O7NceN767A4yO64O6Aekg1b+LX0qxeOl64tjse/ybXN61WjdSg75zbvD7OVdpLvJfe7ZrWwX7Vtvk6oOGxe+uGyHvhKqSmCGw9dFI3BsDTsHgIZX6X19G4pGNTLNteZGjeyztlYOEW7RFZr7+wLXL3bUD2E0PwxS97gtqKAnl/G3XiCNfwqXbvZR3x9x+2+t5XKklbqx2rZ9tGeP+2LPQ/11Pl+NV9l6D4tCcBq0v2X93n+T1SNU4Ym54drpvk2jf1VHd2alHfd1y8fUtvlFVUoV56GhpG0YNm3VND/aqnvLyh1amZivdvvwgfLduBORsOAQBWjLsC6WkpaKwUxib99gJf0u7Wyn//7NyiATq3aKB5THn/eiNXxE6TtMm+VaPayHl6KOqnp/mVaDKb1cXm54ajVo3qndOb6AFPSW3RXweFrK8Eqmv+R13QEkIIzd403oKRd/Xtm9b17fzReubqbjgvIFn87qJ2+N1F7Xylw78MPgePDA3dwOW97B7UKQNP/srTOyfSHkFaAntIhHNe8/p455be6H9uBt5dtB1vLMgDAL/fx8tbqu3Uoj7+c/fF2HrohOYJMT0tNea/xfv9x6bnoHmIXhle792WhQqdao/b+rXHbf3aQwihWRdtNm9C9p7svElbK1EDwJCuzX2v1XXO3vr/Z67uFlQXrV5UYKL//sEBOHyyDOUVVZp12DVSU1Aj4Arlut6t8VPe4VB/FozmVwmgX8em6Nexqe+YaBHwG96Q1TZkCT0ZJW2yB/TrZbUSiZqRhOy9xDVSVWlm58Tb+unfJwB4GgvDJXrAU8pbMuZy0+8tiOampeHdPZfIjwzt5Ev2alonkP7nNvOVRgOlRNgSNXV0X6zadVTzs5d+G9xeoiUtNQV6NTp6vXH0XNurNZZtL8KDOg3U4aSkCOS/OMK3b/oKHRFul4z66dj+4gi/xtdurRtgRf4RpKUItG1SG+2bBB8rXQKq4Ix49YaeQdNaNayF/cdL8efB56Bvh6YYMz0HJ8oqImrw/vwPfVFaEVydY4T3hJQsg+smdbKPh1B9e8Nd6nlLWt7uh+Ou6ozdR0p05+/ZtlHIxLF6/JWaPQ30tA2opzWTWcfH3IcHommEVwtG+1s/MbILvs89iL4dmkbcx35Il7P87t40U52aaXjvtqyYlqFuD3njpp7454I81KuZhr+N6BxUbRFKYNXPu7dlYevBE6hTMw1Lxlh7/8nMvwzAoeJS38lj6ui+mL/pkO5Vd5+zPW1qv1d1A+3XUf93vfScpriul34vsqd+1Q1n1U/HlaorH29BMJJt6BRM9tEyUIL1lXJ1ck9Kin/1z72XddRdlrqkpiewn7TT/PL4EN1Gaa+fxg7G8ZLqRrtzw9Rve3lvle/asgGe+pWxm8buGdAB9yjdIyP1/u2R9cUP5/oL2/ga+bV6hsRicOfmGNzZk7BGD9Tfx4xoUKuGr6OC1ZrUrem3T7dtUgd3XOrfnrRs7GBfNZW604IRU+4J3fW5Sd2aeHyk/77U5+wm+O4vA9ClpXU3ylmFyT5GxqpxYi/npiTItaT3/oOBGndKGqkyat2odlR3gS4fdwVOlVVYerVipd7tG+PLVXsjuoOUrLnhMZyurYKrqT65q4+je+IATPZRu/eyjth79LTunYOA4fakpHJu8/pYPf5Kw32ezRJYCiSKJ63CjdMw2UepSd2aeOv3vUPO4x1bo266s8/4ZmPSjZy3T3xgLxUiszDZW+iFa8/H4M5n4YI2jewOhRzu171aI6/wZFy6ZpI7MdlbqF56Gq7pqX07OpFajdQUjLsqtkHRiELhNSMRkQsw2RMRuQCTPRGRCzDZExG5AJM9EZELMNkTEbkAkz0RkQsw2RMRuYAI9+zVqBYqRCGAXVF+vRmAME8xcJxEjBlIzLgZc/wkYtyJGDPgibuulNKyQXYsSfaxEEJkSyljG8w7zhIxZiAx42bM8ZOIcSdizEB84mY1DhGRCzDZExG5gBOT/bt2BxCFRIwZSMy4GXP8JGLciRgzEIe4HVdnT0RE5nNiyZ6IiEzGZE9E5AKOSfZCiOFCiC1CiDwhxFgb1t9WCLFQCLFRCLFBCPGgMr2JEGKuEGKb8n9jZboQQryhxJsjhOitWtbtyvzbhBC3q6ZfKIRYr3znDSGMPK7cUOypQog1QoiZyvuzhRArlfV8IYSoqUxPV97nKZ9nqpYxTpm+RQgxTDXdkt9FCNFICDFdCLFZCLFJCNHP6dtaCPGwsm/kCiE+F0LUcuK2FkL8WwhRIITIVU2zfNvqrSPGuF9W9pEcIcQ3QohGqs8i2o7R/FbRxKz67FEhhBRCNFPe27utpZS2/wOQCmA7gA4AagJYB6BrnGNoCaC38ro+gK0AugKYBGCsMn0sgJeU1yMAfA9AAOgLYKUyvQmAfOX/xsrrxspnPyvzCuW7V5kU+yMAPgMwU3k/DcCNyut3ANynvL4fwDvK6xsBfKG87qps83QAZyu/RaqVvwuAjwHco7yuCaCRk7c1gNYAdgCordrGdzhxWwMYCKA3gFzVNMu3rd46Yox7KIA05fVLqrgj3o6R/lbRxqxMbwtgDjw3lzZzwraOWzINs8H6AZijej8OwDibY/oWwJUAtgBoqUxrCWCL8noygJtU829RPr8JwGTV9MnKtJYANqum+80XQ5xtAMwHMBjATGWnOKw6QHzbVtn5+imv05T5ROD29s5n1e8CoCE8iVMETHfstoYn2e9RDsg0ZVsPc+q2BpAJ/6Rp+bbVW0cscQd8di2AKVrbJ9x2jOa4iCVmANMB9ACwE9XJ3tZt7ZRqHO+B5LVXmWYL5TKuF4CVAJpLKQ8oHx0E0Fx5rRdzqOl7NabH6nUAYwBUKe+bAjgmpazQWI8vNuXz48r8kf4tsTobQCGAD4Wn+ul9IURdOHhbSyn3Afg7gN0ADsCz7VbB+dvaKx7bVm8dZrkLntItwsSnNT2a4yIqQohrAOyTUq4L+MjWbe2UZO8YQoh6AL4C8JCUslj9mfScRh3TV1UIMQpAgZRyld2xRCgNnkvft6WUvQCcgudS1MeB27oxgGvgOVG1AlAXwHBbg4pSPLat2esQQjwOoA
LAFLOWaQUhRB0AfwPwZLzWaXRbOyXZ74OnjsurjTItroQQNeBJ9FOklF8rkw8JIVoqn7cEUKBM14s51PQ2GtNjcSmAq4UQOwFMhacq5x8AGgkh0jTW44tN+bwhgKIo/pZY7QWwV0q5Unk/HZ7k7+RtPQTADilloZTyDICv4dn+Tt/WXvHYtnrriIkQ4g4AowD8Xkls0cRdhMh/q2h0hKdAsE45LtsAWC2EaBFFzOZu62jrBM38B09JL1/ZSN5GlW5xjkEA+ATA6wHTX4Z/Q8gk5fVI+De2/KxMbwJPfXRj5d8OAE2UzwIbW0aYGP8gVDfQfgn/hqj7ldcPwL8hapryuhv8G7vy4Wnosux3AbAEQCfl9dPKdnbstgZwMYANAOooy/wYwJ+duq0RXGdv+bbVW0eMcQ8HsBFARsB8EW/HSH+raGMO+Gwnquvsbd3WcUumBjbYCHh6wGwH8LgN6+8Pz6VQDoC1yr8R8NTdzQewDcA81Y8gALylxLseQJZqWXcByFP+3amangUgV/nOm4igEchA/INQnew7KDtJnrKDpyvTaynv85TPO6i+/7gS1xaoeq5Y9bsA6AkgW9neM5Sd3NHbGsAzADYry/0UnkTjuG0N4HN42hXOwHMVdXc8tq3eOmKMOw+e+uy1yr93ot2O0fxW0cQc8PlOVCd7W7c1h0sgInIBp9TZExGRhZjsiYhcgMmeiMgFmOyJiFyAyZ6IyAWY7ImIXIDJnojIBf4f8W2oOg4OoHEAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(n_training_examples, losses)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's test how well this encoder works on the test data. To do this, we will iterate over all possible y outcomes for each example. We will pick the y that minimizes the loss. This would be the most compatible y. We can then compute an error rate from this.\n",
"\n",
"After this, we will then try to build a structural SVM on top to see how much gain we get from doing that."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For each example, we will map to 2^3 = 8 examples. That is, each possible outcome. We will then pick the outcome with lowest loss"
]
},
{
"cell_type": "code",
"execution_count": 221,
"metadata": {},
"outputs": [],
"source": [
"y_space = list(itertools.product([0, 1], repeat=n_classes))\n",
"y_space_dim = 2**n_classes\n",
"\n",
"n_train_examples = X_train.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": 222,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.3484848484848485\n"
]
}
],
"source": [
"all_possible_errors = 0\n",
"n_errors = 0\n",
"\n",
"for i in range(X_test.shape[0]):\n",
" x = X_test[i]\n",
" y_true = y_test[i]\n",
" \n",
" X = np.tile(x, (y_space_dim, 1,))\n",
" X = np.concatenate((X, y_space), axis=1).astype('float32')\n",
" X = torch.tensor(X)\n",
" \n",
" X_out = auto_encoder(X)\n",
" mse = nn.MSELoss(reduction='none')(X_out, X).sum(axis=1)\n",
" \n",
" pred_index = torch.argmin(mse).item()\n",
" y_pred = y_space[pred_index] \n",
" n_errors += (y_pred != y_true).sum()\n",
" all_possible_errors += len(y_pred)\n",
"print(n_errors / all_possible_errors)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now have a baseline. We also can this model to obtain out feature map:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's now attempt to train on top of this with a SSVM approach.\n",
"\n",
"We need to feed a few constraints into cvxopt:\n",
"\n",
"* Magnitude of weight vector should equal 1\n",
"\n",
"* Minimize squared norm of weight vector\n",
"\n",
"* Feed in each y constraint\n",
"\n",
"From this, we should get back a w. When we have this w, we remake predictions. Have they improved?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Constraints\n",
"\n",
"\n",
"\n",
"$\\text{min}_w \\frac{1}{2} ||w||^2$\n",
"\n",
"$ w \\cdot \\delta \\Psi_i(y) >= 1$\n",
"\n",
"At this point, let's stop and try this out. First, we need a model that outputs a x and y. I'm going to train a linear regression model that does that. All we want to \n",
"\n",
"\n",
"NEXT STEP:\n",
"\n",
"For each y in train. Compute the feature map difference between each possible y that is NOT y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we need to compute the deltas for each feature mapping."
]
},
{
"cell_type": "code",
"execution_count": 223,
"metadata": {},
"outputs": [],
"source": [
"# Construct delta matrix\n",
"deltas = []\n",
"for i in range(n_train_examples):\n",
" x = X_train[i]\n",
" y_true = y_train[i]\n",
" \n",
" X = np.concatenate((x, y_true)).reshape(1, -1).astype('float32')\n",
" X = torch.tensor(X)\n",
" psi_true = auto_encoder.get_feature_map(X).detach().numpy().flatten()\n",
" \n",
" for y in y_space:\n",
" if not np.array_equal(y, y_true):\n",
" # Combine \n",
" X = np.concatenate((x, y)).reshape(1, -1).astype('float32')\n",
" X = torch.tensor(X)\n",
" psi = auto_encoder.get_feature_map(X).detach().numpy().flatten()\n",
" deltas.append(psi - psi_true)\n",
"delta = np.array(deltas)"
]
},
{
"cell_type": "code",
"execution_count": 230,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"status: optimal\n",
"optimal value 0.09999687023506065\n",
"optimal var [ 1.28987733e-20 -2.50190525e-03 -9.71086140e-21]\n",
"optimal var [0.99970021 0.99999539 0.99970482 ... 0.99970021 0.99999539 0.99970482]\n"
]
}
],
"source": [
"import cvxpy as cp\n",
"from cvxpy.atoms import pnorm\n",
"from cvxpy.atoms.elementwise.power import power\n",
"\n",
"C = 0.1\n",
"n_constraints = delta.shape[0]\n",
"w = cp.Variable(n_hidden)\n",
"slack = cp.Variable(n_constraints, nonneg=True)\n",
"objective = cp.Minimize(0.5 * power(pnorm(w, p=1), 2) + (1. / n_constraints ) * C * cp.atoms.affine.sum.sum(slack))\n",
"constraints = [delta @ w >= 1 - slack]\n",
"prob = cp.Problem(objective, constraints)\n",
"prob.solve() # Returns the optimal value.\n",
"print(\"status:\", prob.status)\n",
"print(\"optimal value\", prob.value)\n",
"print(\"optimal var\", w.value)\n",
"print(\"optimal var\", slack.value)"
]
},
{
"cell_type": "code",
"execution_count": 231,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1.28987733e-20, -2.50190525e-03, -9.71086140e-21])"
]
},
"execution_count": 231,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w.value"
]
},
{
"cell_type": "code",
"execution_count": 232,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.6292929292929293\n"
]
}
],
"source": [
"# Lets test it out\n",
"# Construct delta matrix\n",
"all_possible_errors = 0\n",
"n_errors = 0\n",
"\n",
"for i in range(n_test_examples):\n",
" x = X_test[i]\n",
" y_true = y_test[i]\n",
" \n",
" preds = []\n",
" for y in y_space:\n",
" X = np.concatenate((x, y)).reshape(1, -1).astype('float32')\n",
" X = torch.tensor(X)\n",
" psi = auto_encoder.get_feature_map(X).detach().numpy().flatten()\n",
" out = w.value.dot(psi)\n",
" preds.append(out)\n",
" pred_index = np.argmax(preds)\n",
" y_pred = y_space[pred_index]\n",
" n_errors += (y_pred != y_true).sum()\n",
" all_possible_errors += len(y_pred)\n",
"print(n_errors / all_possible_errors)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}