@thistleknot
Created December 9, 2023 04:21
Dataset Distill v4
#!/usr/bin/env python
# coding: utf-8
import torch
import torch.nn.functional as F
from transformers import GPTNeoForCausalLM, AutoTokenizer
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import random
from scipy import stats
# Parameters
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
NUM_DISTILLED_DATA = 13 # Number of synthetic data points
DISTILLED_SEQ_LEN = 21 # Length of sequences for distillation
EMBEDDING_SIZE = 768 # Adjust based on the model
EVAL_INTERVAL = 1
# Load GPT-Neo model and tokenizer
model_name = "EleutherAI/gpt-neo-125M"
model = GPTNeoForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Load and preprocess dataset
dataset = load_dataset("Abirate/english_quotes")
quotes = [item['quote'] for item in dataset['train']]
filtered_quotes = [q for q in quotes if len(tokenizer.encode(q, truncation=True)) == DISTILLED_SEQ_LEN]
train_quotes, eval_quotes = train_test_split(filtered_quotes, test_size=0.2, random_state=42)
# Adjust sequence_length to match the length of the sequences being measured
def calculate_ema_embeddings(sequence, sequence_length=21):
    """
    Calculate the Exponential Moving Average (EMA) embeddings for a given sequence.

    This function processes an input sequence to compute its EMA embeddings, which are used
    to capture the evolving nature of embeddings in a sequence.

    Parameters:
        sequence (Tensor): The sequence of embeddings.
        sequence_length (int): The length of the sequence for EMA calculation.

    Returns:
        Tensor: The EMA embedding of the input sequence.
    """
    # Remove the batch dimension, assuming it's always 1
    sequence = sequence.squeeze(0)  # shape now [sequence_length, embedding_size]
    # Initialize EMA with the first token's embedding
    alpha = 2.0 / (sequence_length + 1)
    ema_embedding = sequence[0]
    for i in range(1, sequence_length):
        weight = 1 if i == sequence_length - 1 else alpha
        ema_embedding = weight * sequence[i] + (1 - weight) * ema_embedding
    return ema_embedding
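# Illustrative sketch (not part of the original gist): the EMA collapses a
# [1, sequence_length, embedding_size] tensor into a single [embedding_size] vector,
# weighting later tokens more heavily. The random tensor below is a hypothetical stand-in.
if False:
    _demo_seq = torch.randn(1, DISTILLED_SEQ_LEN, EMBEDDING_SIZE)
    _demo_ema = calculate_ema_embeddings(_demo_seq, sequence_length=DISTILLED_SEQ_LEN)
    print(_demo_ema.shape)  # torch.Size([768])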
def gaussian_percentiles(n):
    """
    Generate n evenly spaced percentiles for a Gaussian distribution.

    This function is used to determine the percentiles that correspond to a normal distribution,
    which is crucial for aligning the distilled dataset with the statistical distribution of the original data.

    Parameters:
        n (int): The number of divisions in the Gaussian curve.

    Returns:
        list: A list of percentile values mapped onto a Gaussian distribution.
    """
    # Evenly spaced percentiles
    percentiles = np.linspace(0, 100, n + 1)
    # Mapping these percentiles to the Gaussian distribution
    gaussian_values = [stats.norm.cdf(stats.norm.ppf(p / 100)) * 100 for p in percentiles]
    return gaussian_values
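# Illustrative sketch (not part of the original gist): n divisions yield n + 1 percentile
# values spanning 0..100; since stats.norm.cdf is the inverse of stats.norm.ppf, the
# round-trip returns the evenly spaced percentiles themselves.
if False:
    print(gaussian_percentiles(4))  # approximately [0.0, 25.0, 50.0, 75.0, 100.0]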
# Function to extract embeddings
def extract_embeddings(data, tokenizer, model):
    inputs = tokenizer(data, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs, output_hidden_states=True)
    return outputs.hidden_states[-1].detach()
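# Illustrative sketch (not part of the original gist): the last hidden state of GPT-Neo-125M
# has EMBEDDING_SIZE = 768 features, so a single quote yields a tensor of shape
# [1, num_tokens, 768]. The example string is a hypothetical stand-in.
if False:
    _demo_emb = extract_embeddings("To be or not to be", tokenizer, model)
    print(_demo_emb.shape)  # e.g. torch.Size([1, 6, 768]), depending on tokenization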
# Custom GPT-Neo model for direct embedding input
class CustomGPTNeo(GPTNeoForCausalLM):
    def forward(self, embeddings, labels=None):
        transformer_outputs = self.transformer(inputs_embeds=embeddings, return_dict=True)
        hidden_states = transformer_outputs.last_hidden_state
        lm_logits = self.lm_head(hidden_states)
        return lm_logits
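# Illustrative sketch (not part of the original gist): the custom forward bypasses the token
# embedding lookup and feeds inputs_embeds directly, returning raw logits of shape
# [batch, seq_len, vocab_size] (50257 for GPT-Neo-125M).
if False:
    _demo_model = CustomGPTNeo.from_pretrained(model_name)
    _demo_logits = _demo_model(torch.randn(2, DISTILLED_SEQ_LEN, EMBEDDING_SIZE))
    print(_demo_logits.shape)  # torch.Size([2, 21, 50257])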
# Function to get labels from logits
def get_labels_from_logits(logits, tokenizer):
    probs = torch.softmax(logits, dim=-1)
    _, predicted_token_ids = torch.max(probs, dim=-1)
    return [tokenizer.decode(token_id) for token_id in predicted_token_ids]
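# Illustrative sketch (not part of the original gist): greedy decoding of the highest
# probability token at each position, returning one decoded string per sequence in the batch.
if False:
    _demo_logits = torch.randn(1, DISTILLED_SEQ_LEN, 50257)  # hypothetical logits
    print(get_labels_from_logits(_demo_logits, tokenizer))   # list with one decoded string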
# Fluctuating evaluation function: an alternative would be to derive quantiles over all the eval data and use a static quantile set.
def evaluate_model(model, synthetic_data, sorted_eval_embeddings_dict, num_samples):
    """
    Evaluate the model using synthetic data and a dictionary of sorted evaluation embeddings.

    This function measures the performance of the distilled dataset against a subset of evaluation data.
    It's crucial for assessing how well the synthetic dataset mimics the original data distribution.

    Parameters:
        model (Model): The model to be evaluated.
        synthetic_data (Tensor): The synthetic data used for evaluation.
        sorted_eval_embeddings_dict (dict): A dictionary of sorted evaluation embeddings.
        num_samples (int): The number of samples to use for evaluation.

    Returns:
        float: The evaluation loss.
    """
    sorted_sampled_indices = np.unique(random.sample(range(len(sorted_eval_embeddings_dict)), num_samples))
    sampled_eval_embeddings = torch.stack([sorted_eval_embeddings_dict[i] for i in sorted_sampled_indices])
    sampled_eval_embeddings = sampled_eval_embeddings.squeeze(1)
    with torch.no_grad():
        synthetic_logits = model(synthetic_data)
        eval_logits = model(sampled_eval_embeddings)
        eval_loss = F.kl_div(F.log_softmax(synthetic_logits, dim=-1), F.softmax(eval_logits, dim=-1), reduction='batchmean')
    return eval_loss.item()
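# Illustrative sketch (not part of the original gist): F.kl_div expects log-probabilities as
# its first argument and probabilities as its second, which is why both logit sets are
# normalized before comparison. The random tensors below are hypothetical stand-ins.
if False:
    _syn = torch.randn(NUM_DISTILLED_DATA, DISTILLED_SEQ_LEN, 50257)
    _evl = torch.randn(NUM_DISTILLED_DATA, DISTILLED_SEQ_LEN, 50257)
    print(F.kl_div(F.log_softmax(_syn, dim=-1), F.softmax(_evl, dim=-1), reduction='batchmean').item())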
def embeddings_derive_z_scores(embeddings, centers, sdevs):
    delta_center_embeddings = [e - centers for e in embeddings]
    return delta_center_embeddings / sdevs
# Function to convert z-scores to percentiles
def z_score_to_percentile(z_scores):
    # Apply the CDF of the standard normal distribution to each z-score
    percentiles = stats.norm.cdf(z_scores) * 100  # Multiply by 100 to get percentiles
    return percentiles
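# Illustrative sketch (not part of the original gist): a z-score of 0 maps to the 50th
# percentile, and roughly +/-1.96 maps to the 97.5th / 2.5th percentiles.
if False:
    print(z_score_to_percentile(np.array([-1.96, 0.0, 1.96])))  # approximately [2.5, 50.0, 97.5]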
# Function to sort embeddings based on their quantile rankings and create an index mapping
def sort_embeddings_and_create_mapping(percentiles):
    sorted_indices = np.argsort(percentiles, axis=0)
    return percentiles[sorted_indices], {original: rank for rank, original in enumerate(sorted_indices)}
def sorted_dict(sorted_values, percents, embeddings):
    # Find the position of each original embedding in the sorted list
    index_to_rank = {}
    means = [np.mean(s) for s in sorted_values]
    for rank in range(len(sorted_values)):
        rank_mean = np.mean(sorted_values[rank])
        position = np.where(np.array(means) == rank_mean)[0][0]
        index_to_rank[rank] = embeddings[position]
    return index_to_rank
def align_embeddings(embeddings, index_to_rank):
    aligned_embeddings = [embeddings[index] for index in index_to_rank]
    return torch.stack(aligned_embeddings)
def sort_embeddings(embeddings):
    # Flatten each [1, seq_len, embedding_size] tensor into a single DataFrame row,
    # then sort the rows lexicographically and return the original indices in sorted order.
    dataframes = []
    for t in embeddings:
        dataframes.append(pd.DataFrame(t[0]))
    rows = pd.DataFrame()
    for d in range(0, len(dataframes)):
        df = dataframes[d]
        collapsed_df = pd.melt(df, value_vars=df.columns, value_name='Value')
        collapsed_df = collapsed_df[['Value']]
        collapsed_df.reset_index(drop=True, inplace=True)
        collapsed_df.columns = [d]
        rows = pd.concat([rows, collapsed_df.T], axis=0)
    sorted_embeddings = rows.sort_values(by=list(rows.columns))
    return sorted_embeddings.index
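# Illustrative sketch (not part of the original gist): with three toy "embeddings" holding
# constant values 3, 1 and 2, the returned index lists the original positions in ascending
# order of their flattened rows.
if False:
    _demo = [torch.full((1, 2, 3), float(v)) for v in (3, 1, 2)]
    print(list(sort_embeddings(_demo)))  # [1, 2, 0]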
def derive_quantiles(sorted_embeddings, percentiles):
    """
    Derive quantiles from sorted embeddings based on specified percentiles.

    This function aligns the distilled dataset with quantiles to ensure that it represents
    the statistical characteristics of the original dataset, crucial for dataset distillation.

    Parameters:
        sorted_embeddings (list of Tensor): The sorted embeddings from the original dataset.
        percentiles (list of float): The percentiles used to derive the quantiles.

    Returns:
        list: A list of quantiles derived from the embeddings.
    """
    quantiles = []
    n = len(sorted_embeddings) - 1  # Highest valid index into the sorted records
    for p in percentiles:
        index = p / 100 * n  # Calculate the unrounded index position
        floor_index = int(np.floor(index))
        ceil_index = int(np.ceil(index))
        if floor_index == ceil_index:
            quantile = sorted_embeddings[floor_index]  # No averaging needed; select the record directly
        else:
            # Linearly interpolate between the two neighbouring records
            lower_ratio = ceil_index - index
            upper_ratio = index - floor_index
            quantile = (sorted_embeddings[floor_index] * lower_ratio) + (sorted_embeddings[ceil_index] * upper_ratio)
        quantiles.append(quantile)
    return quantiles
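# Illustrative sketch (not part of the original gist): with 5 sorted records, percentile 25
# lands exactly on index 1, while percentile 30 lands on index 1.2 and is interpolated as
# 0.8 * record[1] + 0.2 * record[2]. The scalar tensors below are hypothetical stand-ins.
if False:
    _recs = [torch.tensor(float(v)) for v in range(5)]
    print(derive_quantiles(_recs, [0, 25, 30, 100]))  # approximately [0.0, 1.0, 1.2, 4.0]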
# Extract embeddings
train_embeddings = [extract_embeddings(text, tokenizer, model) for text in train_quotes]
sorted_train_index = sort_embeddings(train_embeddings)
sorted_train_dict = dict(zip(sorted_train_index,train_embeddings))
sorted_train_embeddings_dict = {k: sorted_train_dict[k] for k in np.unique(list(sorted_train_dict.keys()))}
# Calculate the proportional quantiles based on the percentiles
# (for example, n=6 gives a split similar to Tukey's hinges; here n = NUM_DISTILLED_DATA - 1)
percentiles = gaussian_percentiles(NUM_DISTILLED_DATA-1)
#already sorted!
synthetic_embeddings = derive_quantiles([sorted_train_embeddings_dict[i] for i in list(sorted_train_embeddings_dict.keys())],percentiles)
#if I were to use the dict
synthetic_embeddings_dict = dict(zip(range(0,len(percentiles)),synthetic_embeddings))
# Extract embeddings
eval_embeddings = [extract_embeddings(text, tokenizer, model) for text in eval_quotes]
sorted_eval_index = sort_embeddings(eval_embeddings)
sorted_eval_dict = dict(zip(sorted_eval_index,eval_embeddings))
sorted_eval_embeddings_dict = {k: sorted_eval_dict[k] for k in np.unique(list(sorted_eval_dict.keys()))}
if False:
    train_ema_embeddings = [calculate_ema_embeddings(e) for e in train_embeddings]
    train_center_ema_embeddings = np.mean(train_ema_embeddings, axis=0)
    train_embeddings_sdevs = np.std(train_ema_embeddings, axis=0)
    train_z_score_embeddings = embeddings_derive_z_scores(train_ema_embeddings, train_center_ema_embeddings, train_embeddings_sdevs)
    # Convert each z-score in z_score_embeddings to percentiles
    train_percentile_embeddings = [z_score_to_percentile(z_score) for z_score in train_z_score_embeddings]
    # Sort the list based on the custom key function (custom_sort_key is not defined in this script; this branch is disabled)
    sorted_training_percentiles = sorted(train_embeddings, key=custom_sort_key)
if False:
    train_index_to_rank = sorted_dict(sorted_training_percentiles, train_percentile_embeddings, train_embeddings)
    eval_embeddings = [extract_embeddings(text, tokenizer, model) for text in eval_quotes]
    eval_ema_embeddings = [calculate_ema_embeddings(e) for e in eval_embeddings]
    eval_z_score_embeddings = embeddings_derive_z_scores(eval_ema_embeddings, train_center_ema_embeddings, train_embeddings_sdevs)
    eval_percentile_embeddings = [z_score_to_percentile(z_score) for z_score in eval_z_score_embeddings]
    sorted_eval_percentiles = np.sort(eval_percentile_embeddings)
    eval_index_to_rank = sorted_dict(sorted_eval_percentiles, eval_percentile_embeddings, eval_embeddings)
# Initialize synthetic dataset
if False:
    synthetic_embeddings = np.percentile(train_embeddings, percentiles, axis=0)
    #synthetic_embeddings = torch.randn(NUM_DISTILLED_DATA, DISTILLED_SEQ_LEN, EMBEDDING_SIZE, requires_grad=True)
    #synthetic_z_score_embeddings = embeddings_derive_z_scores(synthetic_embeddings.detach().numpy(), train_center_ema_embeddings, train_embeddings_sdevs)
    synthetic_z_score_embeddings = embeddings_derive_z_scores(synthetic_embeddings, train_center_ema_embeddings, train_embeddings_sdevs)
    synthetic_percentile_embeddings = [z_score_to_percentile(z_score) for z_score in synthetic_z_score_embeddings]
    sorted_synthetic_percentile_embeddings = np.sort(synthetic_percentile_embeddings)
    sorted_synthetic_percentiles = np.sort(sorted_synthetic_percentile_embeddings)
    synthetic_index_to_rank = sorted_dict(sorted_synthetic_percentiles, synthetic_percentile_embeddings, synthetic_embeddings)
# Optimization setup
custom_model = CustomGPTNeo.from_pretrained(model_name)
# Stack the quantile embeddings into a single [NUM_DISTILLED_DATA, seq_len, embedding_size] tensor
# and make it trainable so the optimizer can update it
synthetic_embeddings = torch.cat(synthetic_embeddings, dim=0).squeeze(1).clone().detach().requires_grad_(True)
# Create an optimizer for the tensor
#optimizer = torch.optim.Adam(synthetic_embeddings, lr=LEARNING_RATE)
optimizer = torch.optim.Adam([synthetic_embeddings], lr=LEARNING_RATE)
# Distillation process
for epoch in range(NUM_EPOCHS):
    shuffled_indices = random.sample(range(len(train_embeddings)), len(train_embeddings))
    total_loss = 0.0
    num_batches = len(train_embeddings) // NUM_DISTILLED_DATA
    for batch_idx in range(num_batches):
        optimizer.zero_grad()
        sorted_batch_indices = np.unique(shuffled_indices[batch_idx * NUM_DISTILLED_DATA:(batch_idx + 1) * NUM_DISTILLED_DATA])
        sorted_batch_embeddings = [sorted_train_embeddings_dict[i] for i in sorted_batch_indices]
        batch_train_embeddings = torch.cat(sorted_batch_embeddings, dim=0)
        batch_train_embeddings = batch_train_embeddings.squeeze(1)  # Remove the extra dimension
        # Diagnostic print statement (optional, for confirmation)
        #print("Adjusted batch_train_embeddings shape:", batch_train_embeddings.shape)
        synthetic_logits = custom_model(synthetic_embeddings)
        batch_train_logits = custom_model(batch_train_embeddings)
        loss = F.kl_div(F.log_softmax(synthetic_logits, dim=-1), F.softmax(batch_train_logits, dim=-1), reduction='batchmean')
        total_loss += loss.item()  # Accumulate the loss
        # Print training loss every batch
        print(f"Batch {batch_idx}: Training Loss {loss.item()}")
        loss.backward()
        optimizer.step()
    # Calculate the average loss for the epoch
    avg_training_loss = total_loss / num_batches
    # Evaluate at the end of the epoch
    if epoch % EVAL_INTERVAL == 0 and batch_idx == num_batches - 1:
        eval_loss = evaluate_model(custom_model, synthetic_embeddings, sorted_eval_embeddings_dict, NUM_DISTILLED_DATA)
        print(f"Epoch {epoch}: Average Training Loss {avg_training_loss}, Evaluation Loss {eval_loss}")
# Save synthetic dataset
synthetic_labels = get_labels_from_logits(custom_model(synthetic_embeddings), tokenizer)
pd.DataFrame([' '.join(t) for t in synthetic_labels]).to_csv('distilled_dataset.csv')