Dataset Distillation v3
import torch
import torch.nn.functional as F
from transformers import GPTNeoForCausalLM, AutoTokenizer
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import random
# Parameters
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
NUM_DISTILLED_DATA = 10 # Number of synthetic data points
DISTILLED_SEQ_LEN = 21 # Length of sequences for distillation
EMBEDDING_SIZE = 768 # Adjust based on the model
EVAL_INTERVAL = 1
# Load GPT-Neo model and tokenizer
model_name = "EleutherAI/gpt-neo-125M"
model = GPTNeoForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Function to extract embeddings
def extract_embeddings(data, tokenizer, model):
    inputs = tokenizer(data, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs, output_hidden_states=True)
    # Use the last hidden layer as the embedding, detached from the graph
    return outputs.hidden_states[-1].detach()
# Custom GPT-Neo model for direct embedding input
class CustomGPTNeo(GPTNeoForCausalLM):
    def forward(self, embeddings, labels=None):
        # Feed embeddings directly, bypassing the token embedding layer
        transformer_outputs = self.transformer(inputs_embeds=embeddings, return_dict=True)
        hidden_states = transformer_outputs.last_hidden_state
        lm_logits = self.lm_head(hidden_states)
        return lm_logits
# Function to get labels from logits
def get_labels_from_logits(logits, tokenizer):
    probs = torch.softmax(logits, dim=-1)
    _, predicted_token_ids = torch.max(probs, dim=-1)
    # Each row of predicted_token_ids is a full sequence of token ids
    return [tokenizer.decode(token_ids) for token_ids in predicted_token_ids]
# Evaluation function
def evaluate_model(model, synthetic_data, eval_embeddings, num_samples):
    sampled_indices = random.sample(range(len(eval_embeddings)), num_samples)
    sampled_eval_embeddings = torch.stack([eval_embeddings[i] for i in sampled_indices])
    # Remove the extra batch dimension left by extract_embeddings
    sampled_eval_embeddings = sampled_eval_embeddings.squeeze(1)
    with torch.no_grad():
        synthetic_logits = model(synthetic_data)
        eval_logits = model(sampled_eval_embeddings)
        eval_loss = F.kl_div(F.log_softmax(synthetic_logits, dim=-1), F.softmax(eval_logits, dim=-1), reduction='batchmean')
    return eval_loss.item()
# Load and preprocess dataset
dataset = load_dataset("Abirate/english_quotes")
quotes = [item['quote'] for item in dataset['train']]
filtered_quotes = [q for q in quotes if len(tokenizer.encode(q, truncation=True)) == DISTILLED_SEQ_LEN]
train_quotes, eval_quotes = train_test_split(filtered_quotes, test_size=0.2, random_state=42)
# Extract embeddings
train_embeddings = [extract_embeddings(text, tokenizer, model) for text in train_quotes]
eval_embeddings = [extract_embeddings(text, tokenizer, model) for text in eval_quotes]
# Initialize synthetic dataset
synthetic_data = torch.randn(NUM_DISTILLED_DATA, DISTILLED_SEQ_LEN, EMBEDDING_SIZE, requires_grad=True)
# Optimization setup
custom_model = CustomGPTNeo.from_pretrained(model_name)
optimizer = torch.optim.Adam([synthetic_data], lr=LEARNING_RATE)
# Distillation process
for epoch in range(NUM_EPOCHS):
    shuffled_indices = random.sample(range(len(train_embeddings)), len(train_embeddings))
    total_loss = 0.0
    num_batches = len(train_embeddings) // NUM_DISTILLED_DATA
    for batch_idx in range(num_batches):
        optimizer.zero_grad()
        batch_indices = shuffled_indices[batch_idx * NUM_DISTILLED_DATA:(batch_idx + 1) * NUM_DISTILLED_DATA]
        batch_train_embeddings = torch.stack([train_embeddings[i] for i in batch_indices])
        # Remove the extra batch dimension left by extract_embeddings
        batch_train_embeddings = batch_train_embeddings.squeeze(1)
        # Match the synthetic logits to the real-batch logits via KL divergence
        synthetic_logits = custom_model(synthetic_data)
        batch_train_logits = custom_model(batch_train_embeddings)
        loss = F.kl_div(F.log_softmax(synthetic_logits, dim=-1), F.softmax(batch_train_logits, dim=-1), reduction='batchmean')
        total_loss += loss.item()
        print(f"Batch {batch_idx}: Training Loss {loss.item()}")
        loss.backward()
        optimizer.step()
    # Average training loss for the epoch
    avg_training_loss = total_loss / num_batches
    # Evaluate at the end of the epoch
    if epoch % EVAL_INTERVAL == 0:
        eval_loss = evaluate_model(custom_model, synthetic_data, eval_embeddings, NUM_DISTILLED_DATA)
        print(f"Epoch {epoch}: Average Training Loss {avg_training_loss}, Evaluation Loss {eval_loss}")
# Save synthetic dataset
with torch.no_grad():
    synthetic_logits = custom_model(synthetic_data)
synthetic_labels = get_labels_from_logits(synthetic_logits, tokenizer)
# Each entry is already a decoded string, so save the list directly
pd.DataFrame(synthetic_labels, columns=['text']).to_csv('distilled_dataset.csv', index=False)
thistleknot commented Dec 5, 2023

From the abstract of the Dataset Distillation paper (Wang et al., 2018):

Model distillation aims to distill the knowledge of a complex model into a simpler one. In this paper, we consider an alternative formulation called dataset distillation: we keep the model fixed and instead attempt to distill the knowledge from a large training dataset into a small one. The idea is to synthesize a small number of data points that do not need to come from the correct data distribution, but will, when given to the learning algorithm as training data, approximate the model trained on the original data. For example, we show that it is possible to compress 60,000 MNIST training images into just 10 synthetic distilled images (one per class) and achieve close to original performance with only a few gradient descent steps, given a fixed network initialization. We evaluate our method in various initialization settings and with different learning objectives. Experiments on multiple datasets show the advantage of our approach compared to alternative methods.

Terms:
- x~: Modified input data.
- L: The distillation objective L(x~, η~; theta_0), i.e., the loss on the full training data after one update with the distilled data and learning rate.
- eta~: Modified learning rate.
- l: Loss function typically used in machine learning to measure discrepancies between predictions and true values.
- · : Function argument placeholder indicating the function takes certain arguments.
- X_t: Set of data points at time step t.
- x_{t,j}: The j-th data point in the set at time step t.
- n: Total number of data points in the set.
- η: Learning rate.
- η~: A modified learning rate used in update equations.
- theta_{t+1}: Parameter vector after the update at time step t+1.
- theta_t: Parameter vector at time step t.
- grad_theta_t: Gradient with respect to theta_t.
- Sum_j L^(j): Sum of the objective L over sampled initializations indexed by j.
- X~: Set or collection of modified data points.
- {x~_i}: Individual elements in the set, modified or processed in some way.
- M: The number of elements indexed from 1 to M in the set X~.
- arg min: Argument of the minimum, the values that minimize the function.
- E_{theta_0 ~ p(theta_0)}: Expected value over the distribution of initial parameters theta_0.
- theta_1 and theta_0: Parameter vectors after and before the update, respectively.
- d~: Modified (distilled) matrix of input features.
- t~: Modified (distilled) vector of target values.
- N: Number of data points.
- || d * theta - t ||^2: Squared Euclidean norm, mean squared error (MSE).
- alpha: Step size in update rules.
- grad_x~ and grad_eta~: Gradients with respect to x~ and eta~, respectively.
- I: Identity matrix.
- theta_i and theta_{i+1}: Parameter vectors at the i-th and (i+1)-th iteration, respectively.
- grad_theta_i: Gradient with respect to theta_i.
- x~_i: Modified data at the ith iteration.

3 APPROACH
Given a model and a dataset, we aim to obtain a new, much-reduced synthetic dataset that performs almost as well as the original dataset. We first present our main optimization algorithm for training a network with a fixed initialization with one gradient descent (GD) step (Section 3.1). In Section 3.2, we derive the solution for a more challenging case, where initial weights are random rather than fixed. In Section 3.3, we study a linear network case to help readers understand both the properties and limitations of our method, and discuss the initial weight distributions with which our method works well. In Section 3.4, we extend our approach to more than one gradient descent step and more than one epoch (pass). Finally, Sections 3.5 and 3.6 demonstrate how to obtain distilled images with different initialization distributions and learning objectives.

Consider a training dataset x = {x_i}_{i=1}^N. We parameterize our neural network by theta, and denote by l(x_i, theta) the loss of this network on a data point x_i. Our task is to find the minimizer of the empirical error over the entire training data:

Equation 1

    theta* = arg min_theta (1/N * sum_{i=1}^N l(x_i; theta)) which is equivalent to arg min_theta l(x; theta).

where for notation simplicity we overload the l(·) notation so that l(x, theta) represents the average error of theta over the entire dataset. We make the mild assumption that l is twice-differentiable, which holds true for the majority of modern machine learning models and tasks.

3.1 OPTIMIZING DISTILLED DATA

Standard training usually applies minibatch stochastic gradient descent or its variants. At each step t, a minibatch of training data X_t = {x_{t,j}}_{j=1}^n is sampled to update the current parameters as

    theta_{t+1} = theta_t - η * grad_{theta_t} l(X_t, theta_t),

where η is the learning rate. Such a training process often takes tens of thousands or even millions of update steps to converge. Instead, we aim to learn a tiny set of synthetic distilled training data X~ = {x~_i}_{i=1}^M with M << N and a corresponding learning rate η~ so that a single GD step such as

Equation 2

    theta_1 = theta_0 - η~ * grad_theta_0 l(x~, theta_0),

using these learned synthetic data x~ can greatly boost the performance on the real test set. Given an initial theta_0, we obtain these synthetic data x~ and learning rate η~ by minimizing the objective L below:

Equation 3

    x~*, η~* = arg min_{x~,η~} L(x~, η~; theta_0) = arg min_{x~,η~} l(x, theta_1) = arg min_{x~,η~} l(x, theta_0 - η~ * grad_theta_0 l(x~, theta_0)),

where we derive the new weights theta_1 as a function of the distilled data x~ and learning rate η~ using Equation 2, and then evaluate the new weights over all the training data x. The loss L(x~, η~; theta_0) is differentiable w.r.t. x~ and η~ and can thus be optimized using standard gradient-based methods. In many classification tasks, the data x may contain discrete parts, e.g., class labels in data-label pairs.
For such cases, we fix the discrete parts rather than learn them.
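To make Equation 3 concrete: one outer update differentiates through the inner GD step. Below is a minimal PyTorch sketch, not the paper's code; loss_fn(theta, data), the flat parameter tensor theta0 (with requires_grad=True), and the optimizer over (x_syn, lr_syn) are all assumed names.

import torch

def distillation_step(loss_fn, theta0, x_real, x_syn, lr_syn, outer_opt):
    outer_opt.zero_grad()
    # Inner GD step (Equation 2): theta_1 = theta_0 - eta~ * grad l(x~, theta_0)
    inner_loss = loss_fn(theta0, x_syn)
    (g,) = torch.autograd.grad(inner_loss, theta0, create_graph=True)
    theta1 = theta0 - lr_syn * g
    # Outer objective (Equation 3): evaluate theta_1 on the real data,
    # then backpropagate through the inner step into x~ and eta~
    outer_loss = loss_fn(theta1, x_real)
    outer_loss.backward()
    outer_opt.step()
    return outer_loss.item()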

3.2 DISTILLATION FOR RANDOM INITIALIZATIONS

Unfortunately, the above distilled data optimized for a given initialization do not generalize well to other initializations. The distilled data often look like random noise (e.g., in Figure 2a), as they encode the information of both the training dataset x and a particular network initialization theta_0. To address this issue, we turn to computing a small number of distilled data that can work for networks with random initializations from a specific distribution. We formulate the optimization problem as follows:

Equation 4

    x~*, η~* = arg min_{x~, η~} E_{theta_0 ~ p(theta_0)} [L(x~, η~; theta_0)],

where the network initialization theta_0 is randomly sampled from a distribution p(theta_0). During our optimization, the distilled data are optimized to work well for randomly initialized networks. Algorithm 1 illustrates our main method. In practice, we observe that the final distilled data generalize well to unseen initializations. In addition, these distilled images often look quite informative, encoding the discriminative features of each category (e.g., in Figure 3).

For distilled data to be learned properly, it turns out to be crucial for l(x, ·) to share similar local conditions (e.g., output values, gradient magnitudes) over initializations theta_0 sampled from p(theta_0). In the next section, we derive a lower bound on the size of distilled data needed for a simple model with arbitrary initial theta_0, and discuss its implications for choosing p(theta_0).
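In practice, the expectation in Equation 4 becomes a Monte Carlo average over sampled initializations, which is the inner loop of Algorithm 1. A hedged sketch continuing the names from the previous snippet, with sample_theta0() as an assumed helper that draws a flat parameter tensor from p(theta_0):

def outer_step_random_init(loss_fn, sample_theta0, x_real, x_syn, lr_syn,
                           outer_opt, num_inits=4):
    outer_opt.zero_grad()
    total = 0.0
    for _ in range(num_inits):
        theta0 = sample_theta0().requires_grad_(True)
        # Inner GD step for this sampled initialization
        inner_loss = loss_fn(theta0, x_syn)
        (g,) = torch.autograd.grad(inner_loss, theta0, create_graph=True)
        theta1 = theta0 - lr_syn * g
        outer_loss = loss_fn(theta1, x_real)
        outer_loss.backward()   # accumulates gradients of Sum_j L^(j)
        total += outer_loss.item()
    outer_opt.step()            # updates x~ and eta~ (Line 9 of Algorithm 1)
    return total / num_inits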

3.3 ANALYSIS OF A SIMPLE LINEAR CASE

This section studies our formulation in a simple linear regression problem with quadratic loss. We derive the lower bound of the size of distilled data needed to achieve the same performance as training on the full dataset for arbitrary initialization with one GD step. Consider a dataset x containing N data-target pairs {(d_i, t_i)}_i=1^N, where d_i in R^D and t_i in R, which we represent as two matrices: an
N x D data matrix d and an N x 1 target matrix t. Given mean squared error and a D x 1 weight matrix theta, we have

Equation 5

    l(x, theta) = l((d, t), theta) = (1 / (2N)) || d * theta - t ||^2.

We aim to learn M synthetic data-target pairs x~ = (d~, t~), where d~ is an M x D matrix, t~ an M x 1 matrix (M << N), and η~ the learning rate, to minimize l(x, theta_0 - η~ * grad_{theta_0} l(x~, theta_0)).

Algorithm 1 Dataset Distillation

Input: p(theta_0): distribution of initial weights; M: the number of distilled data
Input: alpha: step size; n: batch size; T: the number of optimization iterations; η~_initial: initial value for η~
1. Initialize x~ = {x~_i}_{i=1}^M randomly, η~ <- η~_initial
2. for each training step t = 1 to T do
3.     Get a minibatch of real training data X_t = {x_{t,j}}_{j=1}^n
4.     Sample a batch of initial weights theta_0^(j) ~ p(theta_0)
5.     for each sampled theta_0^(j) do
6.         Compute updated parameter with GD: theta_1^(j) = theta_0^(j) - η~ * grad_{theta_0^(j)} l(x~, theta_0^(j))
7.         Evaluate the objective function on real training data: L^(j) = l(X_t, theta_1^(j))
8.     end for
9.     Update x~ <- x~ - alpha * grad_{x~} Sum_j L^(j), and η~ <- η~ - alpha * grad_{η~} Sum_j L^(j)
10. end for
Output: distilled data x~ and optimized learning rate η~

The updated weight matrix after one GD step with these distilled data is

Equation 6

    theta_1 = theta_0 - η~ * grad_theta_0 l(x~, theta_0) = theta_0 - (η~ / M) * d~^T * (d~ * theta_0 - t~) = (I - (η~ / M) * d~^T * d~) * theta_0 + (η~ / M) * d~^T * t~.

For the quadratic loss, there always exist learned distilled data x~ that can achieve the same performance as training on the full dataset x (i.e., attaining the global minimum) for any initialization theta_0. For example, for any global minimum solution theta*, choosing d~ and t~ such that (η~ / M) * d~^T * d~ = I and (η~ / M) * d~^T * t~ = theta* (e.g., with M = D, d~ = sqrt(M / η~) * I and t~ = sqrt(M / η~) * theta*) makes Equation 6 yield theta_1 = theta* regardless of theta_0. But how small can the size of the distilled data be? For such models, the global minimum is attained at any theta* satisfying d^T * d * theta* = d^T * t. Substituting Equation 6 into this condition, we have

Equation 7

    d^T * d * (I - (η~ / M) * d~^T * d~) * theta_0 + (η~ / M) * d^T * d * d~^T * t~ = d^T * t.

Here we make the mild assumption that the feature columns of the data matrix d are independent (i.e., d^T * d has full rank). For x~ = (d~, t~) to satisfy the above equation for any theta_0, we must have

Equation 8

    I - (η~ / M) * d~^T * d~ = 0,

which implies that d~^T * d~ = (M / η~) * I. In particular, d~^T * d~ has full rank, and since its rank is at most min(M, D), we must have M >= D.
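This conclusion is easy to verify numerically. The following sketch (toy dimensions assumed, not from the paper) checks that with M = D, d~ = sqrt(M / η~) * I, and t~ = sqrt(M / η~) * theta*, the one-step update of Equation 6 lands exactly on theta* from any theta_0:

import torch

D, M, lr = 5, 5, 0.1                      # toy sizes; M = D saturates the bound
theta_star = torch.randn(D, 1)            # an arbitrary global minimum
d_syn = (M / lr) ** 0.5 * torch.eye(D)    # makes (lr / M) * d~^T d~ = I (Eq. 8)
t_syn = (M / lr) ** 0.5 * theta_star      # makes (lr / M) * d~^T t~ = theta*

for _ in range(3):                        # several random initializations
    theta0 = torch.randn(D, 1)
    # Equation 6: theta_1 = (I - (lr/M) d~^T d~) theta_0 + (lr/M) d~^T t~
    theta1 = (torch.eye(D) - (lr / M) * d_syn.T @ d_syn) @ theta0 \
             + (lr / M) * d_syn.T @ t_syn
    assert torch.allclose(theta1, theta_star, atol=1e-5)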

Discussion. The analysis above only considers a simple case, but it suggests that any small number of distilled data will fail to generalize to arbitrary initial theta_0. This is intuitively expected, as the optimization target l(x, theta_1) = l(x, theta_0 - η~ * grad_{theta_0} l(x~, theta_0)) depends on the local behavior of l(x, ·) around theta_0, which can differ drastically across initializations. The lower bound M >= D is quite restrictive, considering that real datasets often have thousands to hundreds of thousands of dimensions (e.g., images). This analysis motivates us to focus on p(theta_0) distributions that yield similar local conditions. Section 3.5 explores several practical choices. To address the limitation of using a single GD step, we extend our method to multiple GD steps in the next section.

3.4 MULTIPLE GRADIENT DESCENT STEPS AND MULTIPLE EPOCHS

We extend Algorithm 1 to more than one gradient descent step by changing Line 6 to multiple sequential GD steps, each on a different batch of distilled data with its own learning rate, i.e., each step is

Equation 9

    theta_{i+1} = theta_i - η~_i * grad_{theta_i} l(x~_i, theta_i),

and changing Line 9 to backpropagate through all steps. However, naively computing these gradients is expensive in both memory and computation. We therefore exploit a recent technique called back-gradient optimization, which allows significantly faster gradient calculation for such updates in reverse-mode differentiation. Specifically, back-gradient optimization formulates the necessary second-order terms as efficient Hessian-vector products (Pearlmutter, 1994), which can be easily calculated with modern automatic differentiation systems such as PyTorch (Paszke et al., 2017). For further algorithmic details, we refer the reader to prior work (Domke, 2012; Maclaurin et al., 2015).
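A sketch of this unrolling in PyTorch: passing create_graph=True at each inner step keeps the second-order (Hessian-vector-product) terms in the graph, so a single backward pass through the final loss differentiates w.r.t. every x~_i and η~_i. Names follow the earlier sketches; x_syn_steps and lr_steps are assumed per-step distilled batches and learning rates:

def unrolled_outer_loss(loss_fn, theta0, x_real, x_syn_steps, lr_steps):
    theta = theta0
    for x_syn, lr in zip(x_syn_steps, lr_steps):
        # Equation 9: theta_{i+1} = theta_i - eta~_i * grad l(x~_i, theta_i)
        step_loss = loss_fn(theta, x_syn)
        (g,) = torch.autograd.grad(step_loss, theta, create_graph=True)
        theta = theta - lr * g
    # Differentiable w.r.t. all x~_i and eta~_i via the retained graph
    return loss_fn(theta, x_real)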

Multiple epochs. To further improve the performance, we train the network for multiple epochs (passes) over the same sequence of distilled data. In other words, for each epoch, our method cycles through all GD steps, where each step is associated with a batch of distilled data. We do not tie the trained learning rates across epochs as later epochs often use smaller learning rates. In Section 4.1, we find that using multiple steps and multiple epochs is more effective than using just one on neural networks, with the total amount of distilled data fixed.
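Continuing the sketch above, multiple epochs simply repeat the same sequence of distilled steps while keeping a separate learnable η~ per (epoch, step) pair, since the rates are untied across epochs (all names remain illustrative):

def unrolled_epochs(loss_fn, theta0, x_real, x_syn_steps, lr_table):
    # lr_table[e][i] is the learnable rate for step i of epoch e
    theta = theta0
    for epoch_lrs in lr_table:
        for x_syn, lr in zip(x_syn_steps, epoch_lrs):
            step_loss = loss_fn(theta, x_syn)
            (g,) = torch.autograd.grad(step_loss, theta, create_graph=True)
            theta = theta - lr * g
    return loss_fn(theta, x_real)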
