Dataset Distillation v3
import torch
import torch.nn.functional as F
from transformers import GPTNeoForCausalLM, AutoTokenizer
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import random

# Parameters
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
NUM_DISTILLED_DATA = 10  # Number of synthetic data points
DISTILLED_SEQ_LEN = 21   # Length of sequences for distillation
EMBEDDING_SIZE = 768     # Adjust based on the model
EVAL_INTERVAL = 1

# Load GPT-Neo model and tokenizer
model_name = "EleutherAI/gpt-neo-125M"
model = GPTNeoForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Function to extract embeddings (last hidden layer) for a piece of text
def extract_embeddings(data, tokenizer, model):
    inputs = tokenizer(data, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    return outputs.hidden_states[-1].detach()

# Custom GPT-Neo model that accepts embeddings directly instead of token ids
class CustomGPTNeo(GPTNeoForCausalLM):
    def forward(self, embeddings, labels=None):
        transformer_outputs = self.transformer(inputs_embeds=embeddings, return_dict=True)
        hidden_states = transformer_outputs.last_hidden_state
        lm_logits = self.lm_head(hidden_states)
        return lm_logits

# Function to decode the most likely token at each position from logits
def get_labels_from_logits(logits, tokenizer):
    probs = torch.softmax(logits, dim=-1)
    _, predicted_token_ids = torch.max(probs, dim=-1)
    return [tokenizer.decode(token_ids) for token_ids in predicted_token_ids]
# Evaluation function
def evaluate_model(model, synthetic_data, eval_embeddings, num_samples):
    sampled_indices = random.sample(range(len(eval_embeddings)), num_samples)
    sampled_eval_embeddings = torch.stack([eval_embeddings[i] for i in sampled_indices])
    # Remove extra dimension if necessary
    sampled_eval_embeddings = sampled_eval_embeddings.squeeze(1)
    with torch.no_grad():
        synthetic_logits = model(synthetic_data)
        eval_logits = model(sampled_eval_embeddings)
        eval_loss = F.kl_div(F.log_softmax(synthetic_logits, dim=-1),
                             F.softmax(eval_logits, dim=-1),
                             reduction='batchmean')
    return eval_loss.item()
# Load and preprocess dataset
dataset = load_dataset("Abirate/english_quotes")
quotes = [item['quote'] for item in dataset['train']]
filtered_quotes = [q for q in quotes if len(tokenizer.encode(q, truncation=True)) == DISTILLED_SEQ_LEN]
train_quotes, eval_quotes = train_test_split(filtered_quotes, test_size=0.2, random_state=42)

# Extract embeddings
train_embeddings = [extract_embeddings(text, tokenizer, model) for text in train_quotes]
eval_embeddings = [extract_embeddings(text, tokenizer, model) for text in eval_quotes]

# Initialize synthetic dataset (learned directly in embedding space)
synthetic_data = torch.randn(NUM_DISTILLED_DATA, DISTILLED_SEQ_LEN, EMBEDDING_SIZE, requires_grad=True)

# Optimization setup: only the synthetic embeddings are optimized
custom_model = CustomGPTNeo.from_pretrained(model_name)
custom_model.requires_grad_(False)  # freeze model weights; gradients flow only to synthetic_data
optimizer = torch.optim.Adam([synthetic_data], lr=LEARNING_RATE)
# Distillation process
for epoch in range(NUM_EPOCHS):
    shuffled_indices = random.sample(range(len(train_embeddings)), len(train_embeddings))
    total_loss = 0.0
    num_batches = len(train_embeddings) // NUM_DISTILLED_DATA
    for batch_idx in range(num_batches):
        optimizer.zero_grad()
        batch_indices = shuffled_indices[batch_idx * NUM_DISTILLED_DATA:(batch_idx + 1) * NUM_DISTILLED_DATA]
        batch_train_embeddings = torch.stack([train_embeddings[i] for i in batch_indices])
        # Remove the extra batch dimension left over from extract_embeddings
        batch_train_embeddings = batch_train_embeddings.squeeze(1)
        # Diagnostic print statement (optional, for confirmation)
        #print("Adjusted batch_train_embeddings shape:", batch_train_embeddings.shape)
        synthetic_logits = custom_model(synthetic_data)
        with torch.no_grad():  # target logits do not depend on synthetic_data
            batch_train_logits = custom_model(batch_train_embeddings)
        loss = F.kl_div(F.log_softmax(synthetic_logits, dim=-1),
                        F.softmax(batch_train_logits, dim=-1),
                        reduction='batchmean')
        total_loss += loss.item()  # Accumulate the loss
        # Print training loss every batch
        print(f"Batch {batch_idx}: Training Loss {loss.item()}")
        loss.backward()
        optimizer.step()
    # Calculate the average loss for the epoch
    avg_training_loss = total_loss / num_batches
    # Evaluate at the end of the epoch
    if epoch % EVAL_INTERVAL == 0:
        eval_loss = evaluate_model(custom_model, synthetic_data, eval_embeddings, NUM_DISTILLED_DATA)
        print(f"Epoch {epoch}: Average Training Loss {avg_training_loss}, Evaluation Loss {eval_loss}")
# Save synthetic dataset: decode each synthetic sequence back to text
with torch.no_grad():
    synthetic_labels = get_labels_from_logits(custom_model(synthetic_data), tokenizer)
pd.DataFrame({'text': synthetic_labels}).to_csv('distilled_dataset.csv', index=False)
Model distillation aims to distill the knowledge of a complex model into a simpler one. In this paper, we consider an alternative formulation called dataset distillation: we keep the model fixed and instead attempt to distill the knowledge from a large training dataset into a small one. The idea is to synthesize a small number of data points that do not need to come from the correct data distribution, but will, when given to the learning algorithm as training data, approximate the model trained on the original data. For example, we show that it is possible to compress 60,000 MNIST training images into just 10 synthetic distilled images (one per class) and achieve close to original performance with only a few gradient descent steps, given a fixed network initialization. We evaluate our method in various initialization settings and with different learning objectives. Experiments on multiple datasets show the advantage of our approach compared to alternative methods.
Terms:
- x~: Distilled (synthetic) input data.
- L: The distillation objective, a function of the distilled data, the distilled learning rate, and the initial weights, evaluated over the full training set.
- eta~: Distilled (learned) learning rate.
- l: Loss function used to measure the discrepancy between predictions and true values.
- ·: Placeholder for a function argument.
- X_t: Set of data points at time step t.
- x_{t,j}: The j-th data point in the set at time step t.
- n: Total number of data points in the set.
- η: Learning rate.
- η~: A modified learning rate used in update equations.
- theta_{t+1}: Parameter vector after the update at time step t+1.
- theta_t: Parameter vector at time step t.
- grad_theta_t: Gradient with respect to theta_t.
- Sum_j L^(j): Sum of the loss function L over all samples or iterations indexed by j.
- X~: Set or collection of modified data points.
- {x~_i}: Individual elements in the set, modified or processed in some way.
- M: The number of elements indexed from 1 to M in the set X~.
- arg min: Argument of the minimum, the values that minimize the function.
- E_{theta_0 ~ p(theta_0)}: Expected value over the distribution of initial parameters theta_0.
- d~: Distilled (synthetic) matrix of input features.
- t~: Distilled (synthetic) vector of target values.
- theta_1 and theta_0: Parameter vectors after and before the update, respectively.
- N: Number of data points.
- || d * theta - t ||^2: Squared Euclidean norm, mean squared error (MSE).
- alpha: Step size in update rules.
- grad_x~ and grad_eta~: Gradients with respect to x~ and eta~, respectively.
- I: Identity matrix.
- theta_i and theta_i+1: Parameter vectors at the ith and (i+1)th iteration, respectively.
- grad_theta_i: Gradient with respect to theta_i.
- x~_i: Modified data at the ith iteration.
3 APPROACH
Given a model and a dataset, we aim to obtain a new, much-reduced synthetic dataset which performs almost as well as the original dataset. We first present our main optimization algorithm for training a network with a fixed initialization using one gradient descent (GD) step (Section 3.1). In Section 3.2, we derive the solution for the more challenging case, where initial weights are random rather than fixed. In Section 3.3, we further study a linear network case to help readers understand both the properties and the limitations of our method. We also discuss the initial weight distributions with which our method works well. In Section 3.4, we extend our approach to more than one gradient descent step and more than one epoch (pass). Finally, Section 3.5 and Section 3.6 demonstrate how to obtain distilled images with different initialization distributions and learning objectives.
Consider a training dataset x = {x_i}_{i=1}^N. We parameterize our neural network by theta and denote by l(x_i, theta) the loss of this network on a data point x_i. Our task is to find the minimizer of the empirical error over the entire training data:

theta* = arg min_theta (1/N) * Sum_i l(x_i, theta) = arg min_theta l(x, theta),

where for notational simplicity we overload the l(·) notation so that l(x, theta) represents the average error of theta over the entire dataset. We make the mild assumption that l is twice-differentiable, which holds true for the majority of modern machine learning models and tasks.
3.1 OPTIMIZING DISTILLED DATA
Standard training usually applies minibatch stochastic gradient descent or its variants. At each step t, a minibatch of training data X_t = {x_{t,j}}_{j=1}^n is sampled to update the current parameters as

theta_{t+1} = theta_t - η * grad_{theta_t} l(X_t, theta_t),

where η is the learning rate and l(X_t, theta_t) = (1/n) * Sum_j l(x_{t,j}, theta_t). Such a training process often takes tens of thousands or even millions of update steps to converge. Instead, we aim to learn a tiny set of synthetic distilled training data X~ = {x~_i}_{i=1}^M with M << N and a corresponding learning rate η~ so that a single GD step such as

theta_1 = theta_0 - η~ * grad_{theta_0} l(x~, theta_0)    (2)

using these learned synthetic data x~ can greatly boost the performance on the real test set. Given an initial theta_0, we obtain these synthetic data x~ and learning rate η~ by minimizing the objective L below:

x~*, η~* = arg min_{x~, η~} L(x~, η~; theta_0) = arg min_{x~, η~} l(x, theta_0 - η~ * grad_{theta_0} l(x~, theta_0)),

where we derive the new weights theta_1 as a function of the distilled data x~ and learning rate η~ using Equation 2 and then evaluate the new weights over all the training data x. The loss L(x~, η~; theta_0) is differentiable w.r.t. x~ and η~ and can thus be optimized using standard gradient-based methods. In many classification tasks, the data x may contain discrete parts, e.g., class labels in data-label pairs.
For such cases, we fix the discrete parts rather than learn them.
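The objective above is a bilevel problem: the inner GD step on the distilled data has to stay differentiable so that the outer loss on the real training data can be backpropagated into x~ and η~. A minimal PyTorch sketch of this idea for a toy linear classifier follows; the names model_fn and gd_step and all sizes are illustrative assumptions, not taken from the paper or from the gist code above (which instead matches logits with a KL divergence and does not unroll a GD step).

import torch
import torch.nn.functional as F

D, C, M = 20, 10, 10   # feature dim, number of classes, number of distilled points (illustrative)

def model_fn(theta, x):
    # Hypothetical functional model: a single linear layer with theta = (W, b)
    W, b = theta
    return x @ W + b

def gd_step(theta, x_tilde, y_tilde, eta_tilde):
    # One GD step on the distilled data; create_graph=True keeps the step
    # differentiable w.r.t. x_tilde and eta_tilde (the update of Equation 2)
    inner_loss = F.cross_entropy(model_fn(theta, x_tilde), y_tilde)
    grads = torch.autograd.grad(inner_loss, theta, create_graph=True)
    return tuple(p - eta_tilde * g for p, g in zip(theta, grads))

# Distilled inputs and learning rate are learned; the discrete labels stay fixed (one per class)
x_tilde = torch.randn(M, D, requires_grad=True)
y_tilde = torch.arange(C)
eta_tilde = torch.tensor(0.1, requires_grad=True)
opt = torch.optim.Adam([x_tilde, eta_tilde], lr=1e-2)

# Stand-ins for the real training data x and a fixed initialization theta_0
x_real, y_real = torch.randn(512, D), torch.randint(0, C, (512,))
theta0 = (torch.randn(D, C, requires_grad=True), torch.zeros(C, requires_grad=True))

for step in range(200):
    theta1 = gd_step(theta0, x_tilde, y_tilde, eta_tilde)            # theta_1 from Equation 2
    outer_loss = F.cross_entropy(model_fn(theta1, x_real), y_real)   # l(x, theta_1)
    opt.zero_grad()
    outer_loss.backward()
    opt.step()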
3.2 DISTILLATION FOR RANDOM INITIALIZATIONS
Unfortunately, the above distilled data optimized for a given initialization do not generalize well to other initializations. The distilled data often look like random noise (e.g., in Figure 2a) as they encode the information of both the training dataset x and a particular network initialization theta_0. To address this issue, we instead learn a small number of distilled data that work well for networks with random initializations drawn from a specific distribution. We formulate the optimization problem as follows:

x~*, η~* = arg min_{x~, η~} E_{theta_0 ~ p(theta_0)} [ L(x~, η~; theta_0) ],

where the network initialization theta_0 is randomly sampled from a distribution p(theta_0). During our optimization, the distilled data are optimized to work well for randomly initialized networks. Algorithm 1 illustrates our main method. In practice, we observe that the final distilled data generalize well to unseen initializations. In addition, these distilled images often look quite informative, encoding the discriminative features of each category (e.g., in Figure 3).
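Under random initializations, the only change to the sketch above is that theta_0 is re-sampled from p(theta_0) at every outer iteration, so x~ and η~ are optimized in expectation over initializations. A continuation of the previous sketch (sample_theta0 is a hypothetical stand-in for p(theta_0); x_tilde, eta_tilde, opt, model_fn, and gd_step are reused from above):

def sample_theta0():
    # Hypothetical p(theta_0): Gaussian weights scaled by 1/sqrt(D), zero bias
    W = (torch.randn(D, C) / D ** 0.5).requires_grad_()
    b = torch.zeros(C, requires_grad=True)
    return (W, b)

for step in range(200):
    theta0 = sample_theta0()                                          # theta_0 ~ p(theta_0), re-sampled each step
    theta1 = gd_step(theta0, x_tilde, y_tilde, eta_tilde)
    outer_loss = F.cross_entropy(model_fn(theta1, x_real), y_real)    # expectation over p(theta_0) via sampling
    opt.zero_grad()
    outer_loss.backward()
    opt.step()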
For the distilled data to be properly learned, it turns out to be crucial that l(x, ·) shares similar local conditions (e.g., output values, gradient magnitudes) over the initializations theta_0 sampled from p(theta_0). In the next section, we derive a lower bound on the size of distilled data needed for a simple model with an arbitrary initial theta_0, and discuss its implications for choosing p(theta_0).
3.3 ANALYSIS OF A SIMPLE LINEAR CASE
This section studies our formulation in a simple linear regression problem with quadratic loss. We derive a lower bound on the size of distilled data needed to achieve the same performance as training on the full dataset for an arbitrary initialization with one GD step. Consider a dataset x containing N data-target pairs {(d_i, t_i)}_{i=1}^N, where d_i in R^D and t_i in R, which we represent as two matrices: an N x D data matrix d and an N x 1 target matrix t. Given the mean squared error and a D x 1 weight matrix theta, we have

l(x, theta) = (1/(2N)) * || d * theta - t ||^2.

We aim to learn M synthetic data-target pairs x~ = (d~, t~), where d~ is an M x D matrix, t~ an M x 1 matrix (M << N), and η~ the learning rate, to minimize l(x, theta_0 - η~ * grad_{theta_0} l(x~, theta_0)).

Algorithm 1: Dataset Distillation

The updated weight matrix after one GD step with these distilled data is

theta_1 = theta_0 - η~ * grad_{theta_0} l(x~, theta_0) = theta_0 - (η~ / M) * d~^T * (d~ * theta_0 - t~).    (6)

For the quadratic loss, there always exist learned distilled data x~ that can achieve the same performance as training on the full dataset x (i.e., attaining the global minimum) for any initialization theta_0. For example, given any global minimum solution theta*, we can choose d~ proportional to the D x D identity matrix and t~ = d~ * theta*, so that one GD step (with a suitably chosen η~) maps any theta_0 exactly to theta*. But how small can the size of the distilled data be? For such models, the global minimum is attained at any theta* satisfying d^T * d * theta* = d^T * t. Substituting Equation (6) into the condition above, we have

d^T * d * (theta_0 - (η~ / M) * d~^T * (d~ * theta_0 - t~)) = d^T * t.
Here we make the mild assumption that the feature columns of the data matrix d are independent (i.e., d^T * d has full rank). For a x~ = (d~, t~) to satisfy the above equation for any theta_0, we must have

I - (η~ / M) * d~^T * d~ = 0,

which implies that d~^T * d~ has full rank and M >= D.

Discussion. The analysis above only considers a simple case but suggests that any small number of distilled data fail to generalize to arbitrary initial theta_0. This is intuitively expected, as the optimization target l(x, theta_1) = l(x, theta_0 - η~ * grad_{theta_0} l(x~, theta_0)) depends on the local behavior of l(x, ·) around theta_0, which can be drastically different across various initializations theta_0. The lower bound M >= D is quite restrictive, considering that real datasets often have thousands to even hundreds of thousands of dimensions (e.g., images). This analysis motivates us to focus on p(theta_0) distributions that yield similar local conditions. Section 3.5 explores several practical choices. To address the limitation of using a single GD step, we extend our method to multiple GD steps in the next section.
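As a concrete check of the linear-case analysis, the short sketch below (all sizes arbitrary) builds d~ as a scaled identity with t~ = d~ * theta* and η~ = M / c^2, and verifies that a single GD step from a random theta_0 reaches the least-squares solution; with M < D no such choice could work for arbitrary theta_0.

import torch

torch.manual_seed(0)
N, D = 200, 5
d = torch.randn(N, D)
t = torch.randn(N, 1)
theta_star = torch.linalg.lstsq(d, t).solution           # global minimizer of ||d * theta - t||^2

M, c = D, 2.0
d_tilde = c * torch.eye(D)                                # distilled data: scaled identity, so M = D
t_tilde = d_tilde @ theta_star                            # distilled targets: d~ * theta*
eta_tilde = M / c ** 2                                    # chosen so that (eta~/M) * d~^T * d~ = I

theta0 = torch.randn(D, 1)                                # arbitrary initialization
grad = d_tilde.T @ (d_tilde @ theta0 - t_tilde) / M       # gradient of (1/(2M)) * ||d~ * theta - t~||^2
theta1 = theta0 - eta_tilde * grad                        # one GD step (Equation 6)

print(torch.allclose(theta1, theta_star, atol=1e-5))      # True: the global minimum is reached in one step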
3.4 MULTIPLE GRADIENT DESCENT STEPS AND MULTIPLE EPOCHS
We extend Algorithm 1 to more than one gradient descent step by changing Line 6 to multiple sequential GD steps, each on a different batch of distilled data with its own learning rate, i.e., each step i is

theta_{i+1} = theta_i - η~_i * grad_{theta_i} l(x~_i, theta_i),

and changing Line 9 to backpropagate through all steps. However, naively computing these gradients is memory- and computationally expensive. Therefore, we exploit a recent technique called back-gradient optimization, which allows for significantly faster gradient calculation of such updates in reverse-mode differentiation. Specifically, back-gradient optimization formulates the necessary second-order terms as efficient Hessian-vector products (Pearlmutter, 1994), which can be easily calculated with modern automatic differentiation systems such as PyTorch (Paszke et al., 2017). For further algorithmic details, we refer the reader to prior work (Domke, 2012; Maclaurin et al., 2015).
Multiple epochs. To further improve the performance, we train the network for multiple epochs (passes) over the same sequence of distilled data. In other words, for each epoch, our method cycles through all GD steps, where each step is associated with a batch of distilled data. We do not tie the trained learning rates across epochs as later epochs often use smaller learning rates. In Section 4.1, we find that using multiple steps and multiple epochs is more effective than using just one on neural networks, with the total amount of distilled data fixed.
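A naive unrolled version of multiple GD steps over multiple epochs, with learning rates untied per epoch and per step, might look as follows (continuing the toy classifier sketch above and reusing model_fn, gd_step, sample_theta0, y_tilde, x_real, and y_real; S, E, and eta_table are illustrative names). Back-gradient optimization would compute the same gradients without storing the full unrolled graph.

S, E = 4, 3   # GD steps per epoch, and epochs (passes) over the same distilled batches
x_tildes = [torch.randn(M, D, requires_grad=True) for _ in range(S)]   # one distilled batch per step
eta_table = torch.full((E, S), 0.1, requires_grad=True)                # untied learning rate per (epoch, step)
opt = torch.optim.Adam(x_tildes + [eta_table], lr=1e-2)

for outer_step in range(200):
    theta = sample_theta0()
    for e in range(E):
        for i in range(S):
            # theta <- theta - eta~_{e,i} * grad_theta l(x~_i, theta), kept differentiable
            theta = gd_step(theta, x_tildes[i], y_tilde, eta_table[e, i])
    outer_loss = F.cross_entropy(model_fn(theta, x_real), y_real)
    opt.zero_grad()
    outer_loss.backward()   # backpropagates through all E * S unrolled steps
    opt.step()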