dataset distillation v5
#!/usr/bin/env python
# coding: utf-8
import torch
import torch.nn.functional as F
from transformers import GPTNeoForCausalLM, AutoTokenizer
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import random
from scipy import stats
import math
import wandb
import os
from torch.optim.lr_scheduler import LambdaLR
from sklearn.metrics.pairwise import cosine_similarity
os.environ["WANDB_MODE"]="offline"
if torch.cuda.is_available():
    print("CUDA is available. Using GPU.")
    device = torch.device("cuda")
else:
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")
# Parameters
# NUM_EPOCHS has to be bigger than warmup_epochs
NUM_EPOCHS = 10
NUM_DISTILLED_DATA = 13 # Number of synthetic data points
DISTILLED_SEQ_LEN = 21 # Length of sequences for distillation
EVAL_INTERVAL = 1
# Parameters
LEARNING_RATE = 1e-3
peak_lr = LEARNING_RATE*2
desired_lr = LEARNING_RATE / 20
WEIGHT_DECAY=int(NUM_EPOCHS/10)
# Add a warm-up phase
warmup_epochs = 5 # Adjust as needed
# Load GPT-Neo model and tokenizer
model_name = "EleutherAI/gpt-neo-125M"
model = GPTNeoForCausalLM.from_pretrained(model_name)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Load and preprocess dataset
dataset = load_dataset("Abirate/english_quotes")
quotes = [item['quote'] for item in dataset['train']]
filtered_quotes = [q for q in quotes if len(tokenizer.encode(q, truncation=True)) == DISTILLED_SEQ_LEN]
train_quotes, eval_quotes = train_test_split(filtered_quotes, test_size=0.2, random_state=42)
# Create a lambda function for the warm-up and cosine annealing schedule
def lr_lambda(epoch):
    # LambdaLR multiplies the optimizer's base LR by the value returned here, so the
    # target LR is divided by LEARNING_RATE to turn it into a multiplicative factor.
    if epoch < warmup_epochs:
        target_lr = LEARNING_RATE + (peak_lr - LEARNING_RATE) * (epoch + 1) / warmup_epochs
    else:
        target_lr = desired_lr + 0.5 * (peak_lr - desired_lr) * (1 + math.cos(math.pi * (epoch - warmup_epochs) / (NUM_EPOCHS - warmup_epochs)))
    return target_lr / LEARNING_RATE
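# A minimal sanity-check sketch (disabled by default, following the gist's if(False)
# convention): it prints the learning rate LambdaLR would produce each epoch, i.e.
# the base LR times the factor returned by lr_lambda. demo_epoch is illustrative only.
if(False):
    for demo_epoch in range(NUM_EPOCHS):
        print(f"epoch {demo_epoch}: lr ~= {LEARNING_RATE * lr_lambda(demo_epoch):.6f}")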
#modify sequence length to measure
def calculate_ema_embeddings(sequence, sequence_length=21):
    """
    Calculate the Exponential Moving Average (EMA) embeddings for a given sequence.

    This function processes an input sequence to compute its EMA embeddings, which are used
    to capture the evolving nature of embeddings in a sequence.

    Parameters:
        sequence (Tensor): The sequence of embeddings.
        sequence_length (int): The length of the sequence for EMA calculation.

    Returns:
        Tensor: The EMA embedding of the input sequence.
    """
    # Remove the batch dimension assuming it's always 1
    sequence = sequence.squeeze(0)  # shape now [sequence_length, embedding_size]
    # Initialize EMA with the first token's embedding
    alpha = 2.0 / (sequence_length + 1)
    ema_embedding = sequence[0]
    for i in range(1, sequence_length):
        weight = 1 if i == sequence_length - 1 else alpha
        ema_embedding = weight * sequence[i] + (1 - weight) * ema_embedding
    return ema_embedding
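# A minimal usage sketch (disabled by default): feeds a random [1, 21, 768] tensor,
# the shape extract_embeddings() later produces for 21-token quotes, and checks that a
# single [768] EMA vector comes back. The 768 hidden size is assumed for GPT-Neo-125M.
if(False):
    demo_sequence = torch.randn(1, 21, 768)
    demo_ema = calculate_ema_embeddings(demo_sequence, sequence_length=21)
    print(demo_ema.shape)  # expected: torch.Size([768])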
def gaussian_percentiles(n):
    """
    Generate n evenly spaced percentiles for a Gaussian distribution.

    This function is used to determine the percentiles that correspond to a normal distribution,
    which is crucial for aligning the distilled dataset with the statistical distribution of the original data.

    Parameters:
        n (int): The number of divisions in the Gaussian curve.

    Returns:
        list: A list of percentile values mapped onto a Gaussian distribution.
    """
    # Evenly spaced percentiles
    percentiles = np.linspace(0, 100, n+1)
    # Mapping these percentiles to the Gaussian distribution
    gaussian_values = [stats.norm.cdf(stats.norm.ppf(p/100)) * 100 for p in percentiles]
    return gaussian_values
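# A minimal usage sketch (disabled by default): with n divisions the function returns
# n+1 percentile values spanning 0..100, which are later consumed by derive_quantiles().
if(False):
    print(gaussian_percentiles(4))  # five values between 0 and 100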
def extract_embeddings(input_tokens, tokenizer, model):
    """
    Extract input embeddings for a piece of text using the GPT-Neo model.

    Args:
        input_tokens (str or list): Text (or tokens) to embed.
        tokenizer (AutoTokenizer): Tokenizer for GPT-Neo.
        model (GPTNeoForCausalLM): GPT-Neo model.

    Returns:
        Tensor: Detached embedding tensor of shape [1, num_tokens, hidden_size].
    """
    # Step 1: Convert input tokens to input IDs
    input_ids = tokenizer(input_tokens, return_tensors='pt').input_ids.to(device)
    # Step 2: Get embeddings for input IDs
    embeddings = model.get_input_embeddings()(input_ids)
    #inputs = tokenizer(data, return_tensors="pt", padding=True, truncation=True)
    #inputs = {k: v.to(device) for k, v in inputs.items()} # Move inputs to the device
    #outputs = model(**inputs, output_hidden_states=True)
    #return outputs.hidden_states[-1].detach()
    return embeddings.detach()
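# A minimal usage sketch (disabled by default): the returned tensor is detached with
# shape [1, num_tokens, hidden_size]; the demo sentence is illustrative only.
if(False):
    demo_embedding = extract_embeddings("The quick brown fox jumps over the lazy dog.", tokenizer, model)
    print(demo_embedding.shape, demo_embedding.requires_grad)  # e.g. torch.Size([1, n_tokens, 768]) False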
class CustomGPTNeo(GPTNeoForCausalLM):
    def forward(self, embeddings, labels=None):
        # Move embeddings to the same device as the model
        embeddings = embeddings.to(self.transformer.wte.weight.device)
        # Forward pass
        transformer_outputs = self.transformer(inputs_embeds=embeddings, return_dict=True)
        hidden_states = transformer_outputs.last_hidden_state
        lm_logits = self.lm_head(hidden_states)
        return lm_logits
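# A minimal usage sketch (disabled by default): this forward() skips the usual
# labels/loss path and maps input embeddings straight to logits of shape
# [batch, seq_len, vocab_size]. Loading a second copy of the checkpoint here is
# purely illustrative; the script below reuses a single custom_model instance.
if(False):
    demo_custom_model = CustomGPTNeo.from_pretrained(model_name)
    demo_logits = demo_custom_model(torch.randn(2, DISTILLED_SEQ_LEN, demo_custom_model.config.hidden_size))
    print(demo_logits.shape)  # [2, DISTILLED_SEQ_LEN, vocab_size]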
# Fluctuating Evaluation function: alternative would be to derive quantiles over all the eval data and have a static quantile set.
def evaluate_model(model, synthetic_data, sorted_eval_embeddings_dict, num_samples):
    """
    Evaluate the model using synthetic data and a dictionary of sorted evaluation embeddings.

    This function measures the performance of the distilled dataset against a subset of evaluation data.
    It's crucial for assessing how well the synthetic dataset mimics the original data distribution.

    Parameters:
        model (Model): The model to be evaluated.
        synthetic_data (Tensor): The synthetic data used for evaluation.
        sorted_eval_embeddings_dict (dict): A dictionary of sorted evaluation embeddings.
        num_samples (int): The number of samples to use for evaluation.

    Returns:
        float: The evaluation loss.
    """
    sorted_sampled_indices = np.unique(random.sample(range(len(sorted_eval_embeddings_dict)), num_samples))
    sampled_eval_embeddings = torch.stack([sorted_eval_embeddings_dict[i] for i in sorted_sampled_indices])
    sampled_eval_embeddings = sampled_eval_embeddings.squeeze(1)
    with torch.no_grad():
        synthetic_logits = model(synthetic_data)
        eval_logits = model(sampled_eval_embeddings)
        eval_loss = F.kl_div(F.log_softmax(synthetic_logits, dim=-1), F.softmax(eval_logits, dim=-1), reduction='batchmean')
    return eval_loss.item()
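# A minimal usage sketch (disabled by default): toy random embeddings with an assumed
# hidden size of 768 stand in for the real sorted eval dict; the function samples
# num_samples entries and returns a scalar KL divergence between the model's logits
# for the synthetic and the sampled evaluation embeddings.
if(False):
    demo_model = CustomGPTNeo.from_pretrained(model_name)
    demo_eval_dict = {i: torch.randn(1, DISTILLED_SEQ_LEN, 768) for i in range(4)}
    demo_synthetic = torch.randn(2, DISTILLED_SEQ_LEN, 768)
    print(evaluate_model(demo_model, demo_synthetic, demo_eval_dict, num_samples=2))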
def embeddings_derive_z_scores(embeddings, centers, sdevs):
    delta_center_embeddings = [e-centers for e in embeddings]
    return delta_center_embeddings/sdevs
# Function to convert z-scores to percentiles
def z_score_to_percentile(z_scores):
    # Apply the CDF of the standard normal distribution to each z-score
    percentiles = stats.norm.cdf(z_scores) * 100  # Multiply by 100 to get percentiles
    return percentiles
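# A minimal usage sketch (disabled by default): standard-normal z-scores of about
# -1.96, 0, and 1.96 map to roughly the 2.5th, 50th, and 97.5th percentiles.
if(False):
    print(z_score_to_percentile(np.array([-1.96, 0.0, 1.96])))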
# Function to sort embeddings based on their quantile rankings and create index mapping
def sort_embeddings_and_create_mapping(percentiles):
    sorted_indices = np.argsort(percentiles, axis=0)
    return percentiles[sorted_indices], {original: rank for rank, original in enumerate(sorted_indices)}
def sorted_dict(sorted, percents, embeddings):
    # Step 2: Find the position of each original embedding in the sorted list
    index_to_rank = {}
    means = [np.mean(s) for s in sorted]
    for rank in range(len(sorted)):
        rank_mean = np.mean(sorted[rank])
        position = np.where(np.array(means) == rank_mean)[0][0]
        index_to_rank[rank] = embeddings[position]
    return index_to_rank
def align_embeddings(embeddings, index_to_rank):
    aligned_embeddings = [embeddings[index] for index in index_to_rank]
    return torch.stack(aligned_embeddings)
def sort_embeddings(embeddings):
    dataframes = []
    for t in embeddings:
        # First ensure the tensor is on CPU
        t_cpu = t[0].cpu()
        # Now convert the CPU tensor to a DataFrame
        dataframes.append(pd.DataFrame(t_cpu.numpy()))
    rows = pd.DataFrame()
    for d in range(0,len(dataframes)):
        df = dataframes[d]
        collapsed_df = pd.melt(df, value_vars=df.columns, value_name='Value')
        collapsed_df = collapsed_df[['Value']]
        collapsed_df.reset_index(drop=True, inplace=True)
        collapsed_df.columns = [d]
        rows = pd.concat([rows,collapsed_df.T],axis=0)
    sorted_embeddings = rows.sort_values(by=list(rows.columns))
    return(sorted_embeddings.index)
# Note: this second definition of sort_embeddings overrides the DataFrame-based
# version above; embeddings are ranked by the mean of their per-token means.
def sort_embeddings(embeddings):
    averages = []
    for t in embeddings:
        # Ensure the tensor is on CPU and convert it to a DataFrame
        t_cpu = t[0].cpu()
        df = pd.DataFrame(t_cpu.numpy())
        # Compute the mean of means for each embedding
        mean_of_means = df.mean(axis=1).mean()
        averages.append(mean_of_means)
    # Create a Series from the averages
    average_series = pd.Series(averages)
    # Get the sorted indices
    sorted_indices = average_series.sort_values().index
    return sorted_indices
def calculate_ema(series, alpha):
    """
    Calculate the Exponential Moving Average (EMA) for a given series.

    Parameters:
        series (pd.Series): The series of values for which EMA is to be calculated.
        alpha (float): The smoothing factor for EMA.

    Returns:
        float: The EMA of the input series.
    """
    ema = series[0]
    for value in series[1:]:
        ema = alpha * value + (1 - alpha) * ema
    return ema
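# A minimal usage sketch (disabled by default): a constant series has an EMA equal to
# that constant, whatever alpha is.
if(False):
    print(calculate_ema(pd.Series([3.0, 3.0, 3.0, 3.0]), alpha=0.5))  # 3.0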
def sort_embeddings_ema(embeddings):
    ema_values = []
    for t in embeddings:
        # Ensure the tensor is on CPU and convert it to a DataFrame
        t_cpu = t[0].cpu()
        df = pd.DataFrame(t_cpu.numpy())
        # Calculate alpha for EMA
        n = len(df)
        alpha = (n - 1) / n
        # Compute the EMA for each embedding
        ema = df.apply(lambda row: calculate_ema(row, alpha), axis=1)
        mean_ema = ema.mean()
        ema_values.append(mean_ema)
    # Create a Series from the EMA values
    ema_series = pd.Series(ema_values)
    # Get the sorted indices
    sorted_indices = ema_series.sort_values().index
    return sorted_indices
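# A minimal usage sketch (disabled by default): given a list of [1, seq_len, hidden]
# tensors, the function returns the indices that order them by mean EMA value.
# The toy shapes are illustrative only.
if(False):
    demo_embeddings = [torch.full((1, 4, 8), float(v)) for v in (2, 0, 1)]
    print(list(sort_embeddings_ema(demo_embeddings)))  # [1, 2, 0]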
def derive_quantiles(sorted_embeddings, percentiles):
    """
    Derive quantiles from sorted embeddings based on specified percentiles.

    This function aligns the distilled dataset with quantiles to ensure that it represents
    the statistical characteristics of the original dataset, crucial for dataset distillation.

    Parameters:
        sorted_embeddings (list of Tensor): The sorted embeddings from the original dataset.
        percentiles (list of float): The percentiles used to derive the quantiles.

    Returns:
        list: A list of quantiles derived from the embeddings.
    """
    quantiles = []
    n = len(sorted_embeddings) - 1  # Highest index position (number of records minus one)
    for p in percentiles:
        index = p / 100 * n  # Calculate the unrounded index position
        floor_index = int(np.floor(index))
        ceil_index = int(np.ceil(index))
        if floor_index == ceil_index:
            quantile = sorted_embeddings[floor_index]  # No averaging, select the record
        else:
            lower_ratio = ceil_index - index
            upper_ratio = index - floor_index
            quantile = (sorted_embeddings[floor_index] * lower_ratio) + (sorted_embeddings[ceil_index] * upper_ratio)
        quantiles.append(quantile)
    return quantiles
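# A minimal usage sketch (disabled by default): with percentiles [0, 50, 100] over
# three sorted records, the quantiles are the first, middle, and last entries;
# fractional index positions would be linearly interpolated between neighbours.
if(False):
    demo_sorted = [torch.tensor([0.0]), torch.tensor([1.0]), torch.tensor([2.0])]
    print(derive_quantiles(demo_sorted, [0, 50, 100]))  # [tensor([0.]), tensor([1.]), tensor([2.])]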
def find_nearest_embeddings(synthetic_embeddings, embeddings_dict):
    """
    Find the nearest training embeddings for a list of synthetic embeddings using cosine similarity.

    Args:
        synthetic_embeddings (list of torch.Tensor): List of synthetic embeddings to compare.
        embeddings_dict (dict of torch.Tensor): Training embeddings keyed by index for comparison.

    Returns:
        list of torch.Tensor: List of nearest training embeddings for each synthetic embedding.
    """
    nearest_embeddings_list = []
    for synthetic_embedding in synthetic_embeddings:
        similarities = []
        for key, training_embedding in embeddings_dict.items():
            # Calculate cosine similarity between synthetic and training embeddings
            similarity = cosine_similarity(synthetic_embedding.detach().cpu().numpy().reshape(1, -1),
                                           training_embedding.cpu().numpy().reshape(1, -1))[0][0]
            similarities.append(similarity)
        # Find the index of the nearest training embedding
        nearest_idx = similarities.index(max(similarities))
        # Get the actual nearest training embedding
        nearest_embedding = embeddings_dict[nearest_idx]
        nearest_embeddings_list.append(nearest_embedding)
    return nearest_embeddings_list
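# A minimal usage sketch (disabled by default): each synthetic vector is matched to
# the training embedding with the highest cosine similarity. The dict keys are
# assumed to be consecutive integers from 0, as they are for sorted_train_embeddings_dict.
if(False):
    demo_dict = {0: torch.tensor([1.0, 0.0]), 1: torch.tensor([0.0, 1.0])}
    demo_synthetic = [torch.tensor([0.9, 0.1]), torch.tensor([0.1, 0.9])]
    print(find_nearest_embeddings(demo_synthetic, demo_dict))  # [tensor([1., 0.]), tensor([0., 1.])]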
# Function to get labels from logits
#same as generate_tokens_from_embeddings
def get_labels_from_logits(logits):
    probs = torch.softmax(logits, dim=-1)
    _, predicted_token_ids = torch.max(probs, dim=-1)
    return [tokenizer.decode(token_id) for token_id in predicted_token_ids]

#same as get_labels_from_logits
def generate_tokens_from_embeddings(embeddings):
    # Step 3: Calculate logits from embeddings
    logits = model(inputs_embeds=embeddings).logits
    # Step 4: Generate tokens from logits
    generated_ids = torch.argmax(logits, dim=-1)
    generated_tokens = tokenizer.batch_decode(generated_ids)
    return generated_tokens
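# A minimal round-trip sketch (disabled by default): extract_embeddings() followed by
# generate_tokens_from_embeddings() yields the model's greedy token prediction at each
# position, the same decoding that get_labels_from_logits() performs on precomputed logits.
if(False):
    demo_embeddings = extract_embeddings("To be or not to be", tokenizer, model)
    print(generate_tokens_from_embeddings(demo_embeddings))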
# Extract embeddings
train_embeddings = [extract_embeddings(text, tokenizer, model) for text in train_quotes]
#sorted_train_index = sort_embeddings_ema(train_embeddings)
sorted_train_index = sort_embeddings(train_embeddings)
sorted_train_dict = dict(zip(sorted_train_index,train_embeddings))
sorted_train_embeddings_dict = dict(zip(np.unique(list(sorted_train_dict.keys())),[sorted_train_dict[s] for s in np.unique(list(sorted_train_dict.keys()))]))
# Calculate the proportional quantiles based on the percentiles
# Example: calculating for n=6 gives a split similar to Tukey's hinges
percentiles = gaussian_percentiles(NUM_DISTILLED_DATA-1)
# Extract embeddings
eval_embeddings = [extract_embeddings(text, tokenizer, model) for text in eval_quotes]
#sorted_eval_index = sort_embeddings_ema(eval_embeddings)
sorted_eval_index = sort_embeddings(eval_embeddings)
sorted_eval_dict = dict(zip(sorted_eval_index,eval_embeddings))
sorted_eval_embeddings_dict = dict(zip(np.unique(list(sorted_eval_dict.keys())),[sorted_eval_dict[s] for s in np.unique(list(sorted_eval_dict.keys()))]))
if(False):
    train_ema_embeddings = [calculate_ema_embeddings(e) for e in train_embeddings]
    train_center_ema_embeddings = np.mean(train_ema_embeddings,axis=0)
    train_embeddings_sdevs = np.std(train_ema_embeddings,axis=0)
    train_z_score_embeddings = embeddings_derive_z_scores(train_ema_embeddings,train_center_ema_embeddings,train_embeddings_sdevs)
    # Convert each z-score in z_score_embeddings to percentiles
    train_percentile_embeddings = [z_score_to_percentile(z_score) for z_score in train_z_score_embeddings]
    # Sort the list based on the custom key function
    # Note: custom_sort_key is not defined in this gist; this branch stays disabled.
    sorted_training_percentiles = sorted(train_embeddings, key=custom_sort_key)
if(False):
    train_index_to_rank = sorted_dict(sorted_training_percentiles, train_percentile_embeddings, train_embeddings)
    eval_embeddings = [extract_embeddings(text, tokenizer, model) for text in eval_quotes]
    eval_ema_embeddings = [calculate_ema_embeddings(e) for e in eval_embeddings]
    eval_z_score_embeddings = embeddings_derive_z_scores(eval_ema_embeddings,train_center_ema_embeddings,train_embeddings_sdevs)
    eval_percentile_embeddings = [z_score_to_percentile(z_score) for z_score in eval_z_score_embeddings]
    sorted_eval_percentiles = np.sort(eval_percentile_embeddings)
    eval_index_to_rank = sorted_dict(sorted_eval_percentiles, eval_percentile_embeddings, eval_embeddings)
# Initialize synthetic dataset
if(False):
    synthetic_embeddings = np.percentile(train_embeddings,percentiles,axis=0)
    #synthetic_embeddings = torch.randn(NUM_DISTILLED_DATA, DISTILLED_SEQ_LEN, EMBEDDING_SIZE, requires_grad=True)
    #synthetic_z_score_embeddings = embeddings_derive_z_scores(synthetic_embeddings.detach().numpy(),train_center_ema_embeddings,train_embeddings_sdevs)
    synthetic_z_score_embeddings = embeddings_derive_z_scores(synthetic_embeddings,train_center_ema_embeddings,train_embeddings_sdevs)
    synthetic_percentile_embeddings = [z_score_to_percentile(z_score) for z_score in synthetic_z_score_embeddings]
    sorted_synthetic_percentile_embeddings = np.sort(synthetic_percentile_embeddings)
    sorted_synthetic_percentiles = np.sort(sorted_synthetic_percentile_embeddings)
    synthetic_index_to_rank = sorted_dict(sorted_synthetic_percentiles, synthetic_percentile_embeddings, synthetic_embeddings)
# Optimization setup
custom_model = CustomGPTNeo.from_pretrained(model_name)
custom_model.to(device)
#already sorted!
synthetic_embeddings = derive_quantiles([sorted_train_embeddings_dict[i] for i in list(sorted_train_embeddings_dict.keys())],percentiles)
#if I were to use the dict
synthetic_embeddings_dict = dict(zip(range(0,len(percentiles)),synthetic_embeddings))
# Derive quantiles to create synthetic_embeddings
# Concatenate, squeeze, clone, detach, move to the model's device, then enable gradients
# last so the optimizer below receives a leaf tensor (calling .to(device) after
# requires_grad_() would produce a non-leaf tensor that Adam refuses to optimize).
synthetic_embeddings = torch.cat(synthetic_embeddings, dim=0).squeeze(1).clone().detach().to(device).requires_grad_(True)
nearest_embeddings_tensor = find_nearest_embeddings(synthetic_embeddings, sorted_train_embeddings_dict)
# Generate labels from logits
nearest_logit_labels = [generate_tokens_from_embeddings(l) for l in nearest_embeddings_tensor]
print(nearest_logit_labels)
#dict(zip(range(0,len(train_embeddings)),train_embeddings))
dict_embeddings_pos = dict(zip(train_embeddings,range(0,len(train_embeddings))))
dict_pos_quotes = dict(zip(range(0,len(train_quotes)),train_quotes))
nearest_input_strings = []
for e in nearest_embeddings_tensor:
    position = dict_embeddings_pos[e]
    nearest_input_strings.append(dict_pos_quotes[position])
print(nearest_input_strings)
# Generate a random sample from synthetic_labels
#random_sample = np.random.choice(synthetic_labels, 1)
# Print the random sample
#print(random_sample)
# Create an optimizer for the tensor
optimizer = torch.optim.Adam([synthetic_embeddings], lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Create the learning rate scheduler
lr_scheduler = LambdaLR(optimizer, lr_lambda)
# Move each tensor in the dictionary to GPU
for key, tensor in sorted_train_embeddings_dict.items():
    sorted_train_embeddings_dict[key] = tensor.to(device)
for key, tensor in sorted_eval_embeddings_dict.items():
    sorted_eval_embeddings_dict[key] = tensor.to(device)
wandb.init()
eval_loss = evaluate_model(custom_model, synthetic_embeddings, sorted_eval_embeddings_dict, NUM_DISTILLED_DATA)
print(f"Initial: Eval Loss {eval_loss}")
# Step 1: Calculate Train Logits and Store in a Dictionary
train_logits_dict = {} # Dictionary to store train logits
for idx, train_quote in enumerate(train_quotes):
    # Derive embeddings for the current train_quote using the model
    train_embedding = train_embeddings[idx]  # Get the corresponding embedding
    # Calculate logits from the embedding (detached, since these reference logits
    # only feed distance/similarity lookups and never need gradients)
    train_logits = model(inputs_embeds=train_embedding).logits.detach()
    # Store the logits in the dictionary with the quote's index as the key
    train_logits_dict[idx] = train_logits
def find_nearest_embeddings_from_pairs(train_embeddings, train_logits_values, synthetic_logits):
    # Convert train_logits_values to a list of tuples if necessary
    # Assuming each element in train_logits_values is a tensor or a list/tuple that can be converted to a tensor
    # Create a dictionary mapping from train logits pairs to train embeddings
    logits_pairs_to_embeddings = dict(zip(train_logits_values, train_embeddings))
    nearest_embeddings = []
    for synthetic_logit in synthetic_logits:
        nearest_embedding = None
        #min_distance = float('inf')#.to(device)
        min_distance = torch.tensor(float('inf'), device=device)
        synthetic_logit = synthetic_logit.to(device)
        for logits, embeddings in logits_pairs_to_embeddings.items():
            # Convert train_logits_pair to tensor if necessary
            train_logits_tensor = torch.tensor(logits) if not isinstance(logits, torch.Tensor) else logits
            train_logits_tensor = train_logits_tensor.to(device)
            # Calculate Euclidean distance
            distance = torch.dist(synthetic_logit, train_logits_tensor, p=2)
            if distance < min_distance:
                min_distance = distance
                nearest_embedding = embeddings
        nearest_embeddings.append(nearest_embedding)
    return nearest_embeddings
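# A minimal usage sketch (disabled by default): toy 1-D "logits" stand in for the real
# [1, seq_len, vocab] tensors; each synthetic logit is matched to the training embedding
# whose logits lie closest in Euclidean distance. All demo_* names are illustrative only.
if(False):
    demo_train_embeddings = [torch.tensor([10.0]), torch.tensor([20.0])]
    demo_train_logits = [torch.tensor([0.0]), torch.tensor([5.0])]
    demo_synthetic_logits = [torch.tensor([4.0])]
    print(find_nearest_embeddings_from_pairs(demo_train_embeddings, demo_train_logits, demo_synthetic_logits))  # [tensor([20.])]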
embedding_text_dict = dict(zip(train_embeddings,train_quotes))
# Distillation process
for epoch in range(NUM_EPOCHS):
    shuffled_indices = random.sample(range(len(sorted_train_embeddings_dict)), len(sorted_train_embeddings_dict))
    #total_loss = 0.0
    num_batches = len(sorted_train_embeddings_dict) // NUM_DISTILLED_DATA
    nearest_embeddings = find_nearest_embeddings(synthetic_embeddings, sorted_train_embeddings_dict)
    # Step 1: Create a dictionary mapping index to embeddings
    # Move nearest embeddings to the same device as the model
    #nearest_embeddings = nearest_embeddings.to(device)
    # Pass nearest embeddings through custom_model to get logits
    #why would I want the logits of the nearest embeddings of synthetic data? To walk back from logits to embeddings?
    #whether I index from embeddings -> logits or logits -> embeddings, the similarity score would be the same?
    if(False):
        nearest_logits = []
        for tensor in nearest_embeddings:
            output = custom_model(tensor)
            nearest_logits.append(output)
        #e_l_dict = dict(zip(nearest_embeddings,nearest_logits))
        # Generate labels from logits
        # Assuming nearest_logits is a list of torch.Tensor objects
        #nearest_labels = [get_labels_from_logits(logits) for logits in nearest_logits]
    # Move synthetic_embeddings to the same device as the model
    synthetic_embeddings = synthetic_embeddings.to(device)
    synthetic_logits = custom_model(synthetic_embeddings)
    nearest_embeddings_logits = find_nearest_embeddings_from_pairs(train_embeddings, train_logits_dict.values(), synthetic_logits)
    synthetic_logits_embeddings_text = [embedding_text_dict[e] for e in nearest_embeddings_logits]
    # Step 3: Calculate Cosine Similarity and Find Nearest Train Quote
    nearest_train_quotes = []
    for synthetic_logit in synthetic_logits:
        similarities = {}
        # Calculate cosine similarity with each set of train logits
        for idx, train_logit in train_logits_dict.items():
            similarity = cosine_similarity(synthetic_logit.detach().cpu().numpy().reshape(1, -1),
                                           train_logit.detach().cpu().numpy().reshape(1, -1))[0][0]
            similarities[idx] = similarity
        # Find the index of the nearest train quote (highest similarity)
        nearest_idx = max(similarities, key=similarities.get)
        # Retrieve the actual nearest train quote
        nearest_train_quote = train_quotes[nearest_idx]
        nearest_train_quotes.append(nearest_train_quote)
    optimizer.zero_grad()
    nearest_embeddings_tensor = find_nearest_embeddings(synthetic_embeddings, sorted_train_embeddings_dict)
    # Generate labels from logits
    nearest_logit_labels = [generate_tokens_from_embeddings(l) for l in nearest_embeddings_tensor]
    nearest_input_strings = []
    for e in nearest_embeddings_tensor:
        position = dict_embeddings_pos[e]
        nearest_input_strings.append(dict_pos_quotes[position])
    #print(nearest_text)
    print(synthetic_logits_embeddings_text)
    print(nearest_train_quotes)
    print(nearest_input_strings)
    print(nearest_logit_labels)
    batch_indices = []
    batch_train_embeddings_list = []  # Store batch embeddings
    for batch_idx in range(num_batches):
        sorted_batch_indices = np.unique(shuffled_indices[batch_idx * NUM_DISTILLED_DATA:(batch_idx + 1) * NUM_DISTILLED_DATA])
        batch_indices.append(sorted_batch_indices)
        sorted_batch_embeddings = [sorted_train_embeddings_dict[i] for i in sorted_batch_indices]
        batch_train_embeddings = torch.cat(sorted_batch_embeddings, dim=0)
        batch_train_embeddings = batch_train_embeddings.squeeze(1)
        batch_train_embeddings_list.append(batch_train_embeddings)
    # Concatenate batch embeddings into a single tensor
    all_batch_train_embeddings = torch.cat(batch_train_embeddings_list, dim=0)
    # Pass all batch embeddings through custom_model to get logits
    all_batch_train_logits = custom_model(all_batch_train_embeddings)
    # Concatenate synthetic_logits to itself to match the size of all_batch_train_logits
    synthetic_logits_expanded = torch.cat([synthetic_logits] * num_batches, dim=0)
    # Calculate the KL-divergence loss
    loss = F.kl_div(F.log_softmax(synthetic_logits_expanded, dim=-1), F.softmax(all_batch_train_logits, dim=-1), reduction='batchmean')
    # Calculate the average loss for the epoch
    #avg_training_loss = loss / num_batches
    #total_loss += loss.item() # Accumulate the loss
    # Perform backward and optimizer step after processing all batches
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    # Evaluate at the end of the epoch
    eval_loss = evaluate_model(custom_model, synthetic_embeddings, sorted_eval_embeddings_dict, NUM_DISTILLED_DATA)
    # Create a formatted string that includes epoch, learning rate, training loss, and evaluation loss
    log_string = f"Epoch {epoch}: LR {current_lr:.10f}, Train Loss {loss:.3f}, Eval Loss {eval_loss:.3f}"
    # Print the log string
    print(log_string)
    # Log the log string to wandb
    wandb.log({"epoch_log": log_string})
# Finish the wandb session
wandb.finish()
# Save synthetic dataset
#this converts the logits to text
# The first three lists already hold plain strings; only nearest_logit_labels holds
# per-sequence lists of decoded tokens that need joining.
pd.DataFrame(synthetic_logits_embeddings_text).to_csv('synthetic_logits_embeddings_text.csv')
pd.DataFrame(nearest_train_quotes).to_csv('nearest_train_quotes.csv')
pd.DataFrame(nearest_input_strings).to_csv('nearest_input_strings.csv')
pd.DataFrame([' '.join(t) for t in nearest_logit_labels]).to_csv('nearest_logit_labels.csv')
#this finds the closest embeddings to logits using Euclidean distance between paired vectors of different size.
#pd.DataFrame([' '.join(t) for t in synthetic_logit_labels]).to_csv('synthetic_logit_labels.csv')
#pd.DataFrame([' '.join(t) for t in nearest_logit_labels]).to_csv('nearest_logit_labels.csv')
#pd.DataFrame([' '.join(t) for t in nearest_text]).to_csv('nearest_text.csv')