dataset distillation v5
#!/usr/bin/env python
# coding: utf-8
import torch
import torch.nn.functional as F
from transformers import GPTNeoForCausalLM, AutoTokenizer
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import numpy as np
import random
from scipy import stats
import math
import wandb
import os
from torch.optim.lr_scheduler import LambdaLR

os.environ["WANDB_MODE"] = "offline"
if torch.cuda.is_available():
    print("CUDA is available. Using GPU.")
    device = torch.device("cuda")
else:
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")
# Parameters
# NUM_EPOCHS has to be bigger than warmup_epochs
NUM_EPOCHS = 10
NUM_DISTILLED_DATA = 13  # Number of synthetic data points
DISTILLED_SEQ_LEN = 21  # Length of sequences for distillation
EVAL_INTERVAL = 1
LEARNING_RATE = 1e-3
peak_lr = LEARNING_RATE * 2
desired_lr = LEARNING_RATE / 20
WEIGHT_DECAY = int(NUM_EPOCHS / 10)
# Add a warm-up phase
warmup_epochs = 5  # Adjust as needed
# Load GPT-Neo model and tokenizer
model_name = "EleutherAI/gpt-neo-125M"
model = GPTNeoForCausalLM.from_pretrained(model_name)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Load and preprocess dataset
dataset = load_dataset("Abirate/english_quotes")
quotes = [item['quote'] for item in dataset['train']]
filtered_quotes = [q for q in quotes if len(tokenizer.encode(q, truncation=True)) == DISTILLED_SEQ_LEN]
train_quotes, eval_quotes = train_test_split(filtered_quotes, test_size=0.2, random_state=42)
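# Only quotes that tokenize to exactly DISTILLED_SEQ_LEN tokens are kept, so every
# training and evaluation example shares the same fixed sequence length and the
# extracted embedding tensors can be stacked without padding.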
# Lambda for the warm-up and cosine annealing schedule.
# Note: LambdaLR multiplies the optimizer's base learning rate by the value returned
# here, so the schedule is expressed as a factor of LEARNING_RATE.
def lr_lambda(epoch):
    if epoch < warmup_epochs:
        target_lr = LEARNING_RATE + (peak_lr - LEARNING_RATE) * (epoch + 1) / warmup_epochs
    else:
        target_lr = desired_lr + 0.5 * (peak_lr - desired_lr) * (1 + math.cos(math.pi * (epoch - warmup_epochs) / (NUM_EPOCHS - warmup_epochs)))
    return target_lr / LEARNING_RATE
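# With the defaults above, the returned factor ramps from 1.2 at epoch 0 up to 2.0 at the
# end of warm-up (epoch 4), then decays along a half-cosine back toward
# desired_lr / LEARNING_RATE = 0.05 over the remaining epochs.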
# Modify sequence length to measure
def calculate_ema_embeddings(sequence, sequence_length=21):
    """
    Calculate the Exponential Moving Average (EMA) embeddings for a given sequence.

    This function processes an input sequence to compute its EMA embedding, which is used
    to capture the evolving nature of embeddings in a sequence.

    Parameters:
        sequence (Tensor): The sequence of embeddings.
        sequence_length (int): The length of the sequence for EMA calculation.

    Returns:
        Tensor: The EMA embedding of the input sequence.
    """
    # Remove the batch dimension, assuming it's always 1
    sequence = sequence.squeeze(0)  # shape now [sequence_length, embedding_size]
    alpha = 2.0 / (sequence_length + 1)
    # Initialize EMA with the first token's embedding
    ema_embedding = sequence[0]
    for i in range(1, sequence_length):
        weight = 1 if i == sequence_length - 1 else alpha
        ema_embedding = weight * sequence[i] + (1 - weight) * ema_embedding
    return ema_embedding
def gaussian_percentiles(n):
    """
    Generate n evenly spaced percentiles for a Gaussian distribution.

    This function is used to determine the percentiles that correspond to a normal
    distribution, which is crucial for aligning the distilled dataset with the
    statistical distribution of the original data.

    Parameters:
        n (int): The number of divisions in the Gaussian curve.

    Returns:
        list: A list of percentile values mapped onto a Gaussian distribution.
    """
    # Evenly spaced percentiles
    percentiles = np.linspace(0, 100, n + 1)
    # Mapping these percentiles to the Gaussian distribution
    gaussian_values = [stats.norm.cdf(stats.norm.ppf(p / 100)) * 100 for p in percentiles]
    return gaussian_values
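# Note: stats.norm.cdf(stats.norm.ppf(p / 100)) * 100 is mathematically the identity
# (the CDF composed with its inverse), so this returns the evenly spaced percentiles
# unchanged; it is kept as written to preserve the original behaviour.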
def extract_embeddings(input_tokens, tokenizer, model):
    """
    Extract input embeddings for a piece of text using the GPT-Neo model.

    Args:
        input_tokens (str): The input text to embed.
        tokenizer (AutoTokenizer): Tokenizer for GPT-Neo.
        model (GPTNeoForCausalLM): GPT-Neo model.

    Returns:
        Tensor: Detached input embeddings of shape [1, sequence_length, embedding_size].
    """
    # Step 1: Convert the input text to input IDs
    input_ids = tokenizer(input_tokens, return_tensors='pt').input_ids.to(device)
    # Step 2: Look up embeddings for the input IDs
    embeddings = model.get_input_embeddings()(input_ids)
    # Alternative: use the last hidden state instead of the input embeddings
    #inputs = tokenizer(data, return_tensors="pt", padding=True, truncation=True)
    #inputs = {k: v.to(device) for k, v in inputs.items()}  # Move inputs to the device
    #outputs = model(**inputs, output_hidden_states=True)
    #return outputs.hidden_states[-1].detach()
    return embeddings.detach()
class CustomGPTNeo(GPTNeoForCausalLM):
    def forward(self, embeddings, labels=None):
        # Move embeddings to the same device as the model
        embeddings = embeddings.to(self.transformer.wte.weight.device)
        # Forward pass directly from embeddings (skips the token embedding lookup)
        transformer_outputs = self.transformer(inputs_embeds=embeddings, return_dict=True)
        hidden_states = transformer_outputs.last_hidden_state
        lm_logits = self.lm_head(hidden_states)
        return lm_logits
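# Design note: CustomGPTNeo.forward takes embeddings rather than token IDs and returns
# raw logits, so gradients from the distillation loss can flow straight back into the
# synthetic embedding tensor instead of stopping at a discrete token lookup.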
# Fluctuating evaluation function: an alternative would be to derive quantiles over all
# the eval data and use a static quantile set.
def evaluate_model(model, synthetic_data, sorted_eval_embeddings_dict, num_samples):
    """
    Evaluate the model using synthetic data and a dictionary of sorted evaluation embeddings.

    This function measures the performance of the distilled dataset against a subset of
    evaluation data. It's crucial for assessing how well the synthetic dataset mimics the
    original data distribution.

    Parameters:
        model (Model): The model to be evaluated.
        synthetic_data (Tensor): The synthetic data used for evaluation.
        sorted_eval_embeddings_dict (dict): A dictionary of sorted evaluation embeddings.
        num_samples (int): The number of samples to use for evaluation.

    Returns:
        float: The evaluation loss.
    """
    sorted_sampled_indices = np.unique(random.sample(range(len(sorted_eval_embeddings_dict)), num_samples))
    sampled_eval_embeddings = torch.stack([sorted_eval_embeddings_dict[i] for i in sorted_sampled_indices])
    sampled_eval_embeddings = sampled_eval_embeddings.squeeze(1)
    with torch.no_grad():
        synthetic_logits = model(synthetic_data)
        eval_logits = model(sampled_eval_embeddings)
        eval_loss = F.kl_div(F.log_softmax(synthetic_logits, dim=-1), F.softmax(eval_logits, dim=-1), reduction='batchmean')
    return eval_loss.item()
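# Each call draws a fresh random sample of num_samples eval embeddings (sorted by index)
# and compares their logits against the synthetic logits with KL divergence, so the
# reported eval loss fluctuates from call to call, as the comment above notes.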
def embeddings_derive_z_scores(embeddings, centers, sdevs):
    # Collect the centered embeddings into an array so the element-wise division works
    delta_center_embeddings = np.array([e - centers for e in embeddings])
    return delta_center_embeddings / sdevs

# Function to convert z-scores to percentiles
def z_score_to_percentile(z_scores):
    # Apply the CDF of the standard normal distribution to each z-score
    percentiles = stats.norm.cdf(z_scores) * 100  # Multiply by 100 to get percentiles
    return percentiles
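# For example, z_score_to_percentile(0.0) returns 50.0 and z_score_to_percentile(1.96)
# returns roughly 97.5, i.e. the percentile of a standard normal at that z-score.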
# Function to sort embeddings based on their quantile rankings and create an index mapping
def sort_embeddings_and_create_mapping(percentiles):
    sorted_indices = np.argsort(percentiles, axis=0)
    return percentiles[sorted_indices], {original: rank for rank, original in enumerate(sorted_indices)}

def sorted_dict(sorted_values, percents, embeddings):
    # Find the position of each original embedding in the sorted list
    index_to_rank = {}
    means = np.array([np.mean(s) for s in sorted_values])
    for rank in range(len(sorted_values)):
        rank_mean = np.mean(sorted_values[rank])
        position = np.where(means == rank_mean)[0][0]
        index_to_rank[rank] = embeddings[position]
    return index_to_rank

def align_embeddings(embeddings, index_to_rank):
    aligned_embeddings = [embeddings[index] for index in index_to_rank]
    return torch.stack(aligned_embeddings)
def sort_embeddings(embeddings):
    dataframes = []
    for t in embeddings:
        # First ensure the tensor is on CPU
        t_cpu = t[0].cpu()
        # Now convert the CPU tensor to a DataFrame
        dataframes.append(pd.DataFrame(t_cpu.numpy()))
    rows = pd.DataFrame()
    for d in range(0, len(dataframes)):
        df = dataframes[d]
        collapsed_df = pd.melt(df, value_vars=df.columns, value_name='Value')
        collapsed_df = collapsed_df[['Value']]
        collapsed_df.reset_index(drop=True, inplace=True)
        collapsed_df.columns = [d]
        rows = pd.concat([rows, collapsed_df.T], axis=0)
    sorted_embeddings = rows.sort_values(by=list(rows.columns))
    return sorted_embeddings.index
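# Note: the mean-of-means version of sort_embeddings defined next shadows the
# melt-and-sort version above, so only the later definition is used below.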
def sort_embeddings(embeddings):
    averages = []
    for t in embeddings:
        # Ensure the tensor is on CPU and convert it to a DataFrame
        t_cpu = t[0].cpu()
        df = pd.DataFrame(t_cpu.numpy())
        # Compute the mean of means for each embedding
        mean_of_means = df.mean(axis=1).mean()
        averages.append(mean_of_means)
    # Create a Series from the averages
    average_series = pd.Series(averages)
    # Get the sorted indices
    sorted_indices = average_series.sort_values().index
    return sorted_indices
def calculate_ema(series, alpha):
    """
    Calculate the Exponential Moving Average (EMA) for a given series.

    Parameters:
        series (pd.Series): The series of values for which the EMA is to be calculated.
        alpha (float): The smoothing factor for the EMA.

    Returns:
        float: The EMA of the input series.
    """
    ema = series.iloc[0]
    for value in series.iloc[1:]:
        ema = alpha * value + (1 - alpha) * ema
    return ema
def sort_embeddings_ema(embeddings):
    ema_values = []
    for t in embeddings:
        # Ensure the tensor is on CPU and convert it to a DataFrame
        t_cpu = t[0].cpu()
        df = pd.DataFrame(t_cpu.numpy())
        # Calculate alpha for the EMA
        n = len(df)
        alpha = (n - 1) / n
        # Compute the EMA across each token's embedding dimensions, then take the mean
        ema = df.apply(lambda row: calculate_ema(row, alpha), axis=1)
        mean_ema = ema.mean()
        ema_values.append(mean_ema)
    # Create a Series from the EMA values
    ema_series = pd.Series(ema_values)
    # Get the sorted indices
    sorted_indices = ema_series.sort_values().index
    return sorted_indices
def derive_quantiles(sorted_embeddings, percentiles):
    """
    Derive quantiles from sorted embeddings based on specified percentiles.

    This function aligns the distilled dataset with quantiles to ensure that it represents
    the statistical characteristics of the original dataset, which is crucial for dataset
    distillation.

    Parameters:
        sorted_embeddings (list of Tensor): The sorted embeddings from the original dataset.
        percentiles (list of float): The percentiles used to derive the quantiles.

    Returns:
        list: A list of quantiles derived from the embeddings.
    """
    quantiles = []
    n = len(sorted_embeddings) - 1  # Index of the last record
    for p in percentiles:
        index = p / 100 * n  # Calculate the unrounded index position
        floor_index = int(np.floor(index))
        ceil_index = int(np.ceil(index))
        if floor_index == ceil_index:
            quantile = sorted_embeddings[floor_index]  # No averaging, select the record
        else:
            lower_ratio = ceil_index - index
            upper_ratio = index - floor_index
            quantile = (sorted_embeddings[floor_index] * lower_ratio) + (sorted_embeddings[ceil_index] * upper_ratio)
        quantiles.append(quantile)
    return quantiles
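# Worked example (hypothetical sizes): with 101 sorted embeddings, n = 100, so p = 50
# gives index 50.0 and record 50 is returned directly, while p = 12.5 gives index 12.5
# and the result is 0.5 * record 12 + 0.5 * record 13 (linear interpolation between the
# two neighbouring records).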
def find_nearest_embeddings(synthetic_embeddings, embeddings_dict):
    """
    Find the nearest training embedding for each synthetic embedding using cosine similarity.

    Args:
        synthetic_embeddings (Tensor or list of torch.Tensor): Synthetic embeddings to compare.
        embeddings_dict (dict of torch.Tensor): Training embeddings keyed by index.

    Returns:
        list of torch.Tensor: The nearest training embedding for each synthetic embedding.
    """
    nearest_embeddings_list = []
    for synthetic_embedding in synthetic_embeddings:
        similarities = []
        for key, training_embedding in embeddings_dict.items():
            # Calculate cosine similarity between the flattened synthetic and training embeddings
            similarity = cosine_similarity(synthetic_embedding.detach().cpu().numpy().reshape(1, -1),
                                           training_embedding.cpu().numpy().reshape(1, -1))[0][0]
            similarities.append(similarity)
        # Find the index of the nearest training embedding
        nearest_idx = similarities.index(max(similarities))
        # Get the actual nearest training embedding
        nearest_embedding = embeddings_dict[nearest_idx]
        nearest_embeddings_list.append(nearest_embedding)
    return nearest_embeddings_list
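# Note: nearest_idx is a position in the iteration order of embeddings_dict; looking it up
# as a key works here because the dictionaries built below use contiguous 0..N-1 integer
# keys. Each comparison flattens the embedding tensors to single vectors before taking
# cosine similarity.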
# Function to get labels from logits
# (same idea as generate_tokens_from_embeddings)
def get_labels_from_logits(logits):
    probs = torch.softmax(logits, dim=-1)
    _, predicted_token_ids = torch.max(probs, dim=-1)
    return [tokenizer.decode(token_id) for token_id in predicted_token_ids]

# (same idea as get_labels_from_logits)
def generate_tokens_from_embeddings(embeddings):
    # Calculate logits from the embeddings
    logits = model(inputs_embeds=embeddings).logits
    # Generate tokens from the logits
    generated_ids = torch.argmax(logits, dim=-1)
    generated_tokens = tokenizer.batch_decode(generated_ids)
    return generated_tokens
# Extract training embeddings and sort them
train_embeddings = [extract_embeddings(text, tokenizer, model) for text in train_quotes]
#sorted_train_index = sort_embeddings_ema(train_embeddings)
sorted_train_index = sort_embeddings(train_embeddings)
sorted_train_dict = dict(zip(sorted_train_index, train_embeddings))
sorted_train_embeddings_dict = dict(zip(np.unique(list(sorted_train_dict.keys())),
                                        [sorted_train_dict[s] for s in np.unique(list(sorted_train_dict.keys()))]))

# Calculate the proportional quantiles based on the percentiles
# Example: calculating for n=6 is similar to Tukey's hinges
percentiles = gaussian_percentiles(NUM_DISTILLED_DATA - 1)

# Extract evaluation embeddings and sort them the same way
eval_embeddings = [extract_embeddings(text, tokenizer, model) for text in eval_quotes]
#sorted_eval_index = sort_embeddings_ema(eval_embeddings)
sorted_eval_index = sort_embeddings(eval_embeddings)
sorted_eval_dict = dict(zip(sorted_eval_index, eval_embeddings))
sorted_eval_embeddings_dict = dict(zip(np.unique(list(sorted_eval_dict.keys())),
                                       [sorted_eval_dict[s] for s in np.unique(list(sorted_eval_dict.keys()))]))
if False:  # disabled EMA / z-score path, kept for reference
    train_ema_embeddings = [calculate_ema_embeddings(e) for e in train_embeddings]
    train_center_ema_embeddings = np.mean(train_ema_embeddings, axis=0)
    train_embeddings_sdevs = np.std(train_ema_embeddings, axis=0)
    train_z_score_embeddings = embeddings_derive_z_scores(train_ema_embeddings, train_center_ema_embeddings, train_embeddings_sdevs)
    # Convert each z-score in z_score_embeddings to percentiles
    train_percentile_embeddings = [z_score_to_percentile(z_score) for z_score in train_z_score_embeddings]
    # Sort the list based on the custom key function (custom_sort_key is not defined in this script)
    sorted_training_percentiles = sorted(train_embeddings, key=custom_sort_key)

if False:  # disabled ranking path, kept for reference
    train_index_to_rank = sorted_dict(sorted_training_percentiles, train_percentile_embeddings, train_embeddings)
    eval_embeddings = [extract_embeddings(text, tokenizer, model) for text in eval_quotes]
    eval_ema_embeddings = [calculate_ema_embeddings(e) for e in eval_embeddings]
    eval_z_score_embeddings = embeddings_derive_z_scores(eval_ema_embeddings, train_center_ema_embeddings, train_embeddings_sdevs)
    eval_percentile_embeddings = [z_score_to_percentile(z_score) for z_score in eval_z_score_embeddings]
    sorted_eval_percentiles = np.sort(eval_percentile_embeddings)
    eval_index_to_rank = sorted_dict(sorted_eval_percentiles, eval_percentile_embeddings, eval_embeddings)
# Initialize synthetic dataset
if False:  # disabled: initialize synthetic embeddings from raw percentiles, kept for reference
    synthetic_embeddings = np.percentile(train_embeddings, percentiles, axis=0)
    #synthetic_embeddings = torch.randn(NUM_DISTILLED_DATA, DISTILLED_SEQ_LEN, EMBEDDING_SIZE, requires_grad=True)
    #synthetic_z_score_embeddings = embeddings_derive_z_scores(synthetic_embeddings.detach().numpy(), train_center_ema_embeddings, train_embeddings_sdevs)
    synthetic_z_score_embeddings = embeddings_derive_z_scores(synthetic_embeddings, train_center_ema_embeddings, train_embeddings_sdevs)
    synthetic_percentile_embeddings = [z_score_to_percentile(z_score) for z_score in synthetic_z_score_embeddings]
    sorted_synthetic_percentile_embeddings = np.sort(synthetic_percentile_embeddings)
    sorted_synthetic_percentiles = np.sort(sorted_synthetic_percentile_embeddings)
    synthetic_index_to_rank = sorted_dict(sorted_synthetic_percentiles, synthetic_percentile_embeddings, synthetic_embeddings)
# Optimization setup
custom_model = CustomGPTNeo.from_pretrained(model_name)
custom_model.to(device)

# Already sorted! Derive quantiles to create the synthetic embeddings
synthetic_embeddings = derive_quantiles([sorted_train_embeddings_dict[i] for i in list(sorted_train_embeddings_dict.keys())], percentiles)
# If I were to use the dict
synthetic_embeddings_dict = dict(zip(range(0, len(percentiles)), synthetic_embeddings))
# Concatenate, squeeze, move to device, clone/detach, and set requires_grad to True so the
# optimizer receives a leaf tensor that already lives on the right device
synthetic_embeddings = torch.cat(synthetic_embeddings, dim=0).squeeze(1).clone().detach().to(device).requires_grad_(True)
nearest_embeddings_tensor = find_nearest_embeddings(synthetic_embeddings, sorted_train_embeddings_dict)
# Generate labels from logits
nearest_logit_labels = [generate_tokens_from_embeddings(l) for l in nearest_embeddings_tensor]
print(nearest_logit_labels)
#dict(zip(range(0,len(train_embeddings)),train_embeddings))
dict_embeddings_pos = dict(zip(train_embeddings, range(0, len(train_embeddings))))
dict_pos_quotes = dict(zip(range(0, len(train_quotes)), train_quotes))
nearest_input_strings = []
for e in nearest_embeddings_tensor:
    position = dict_embeddings_pos[e]
    nearest_input_strings.append(dict_pos_quotes[position])
print(nearest_input_strings)
# Generate a random sample from synthetic_labels
#random_sample = np.random.choice(synthetic_labels, 1)
# Print the random sample
#print(random_sample)

# Create an optimizer for the synthetic embeddings tensor
optimizer = torch.optim.Adam([synthetic_embeddings], lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Create the learning rate scheduler
lr_scheduler = LambdaLR(optimizer, lr_lambda)
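# Only the synthetic embeddings tensor is handed to the optimizer, so the distillation
# loop updates the 13 synthetic data points while the GPT-Neo weights themselves are
# never modified.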
# Move each tensor in the dictionaries to the GPU
for key, tensor in sorted_train_embeddings_dict.items():
    sorted_train_embeddings_dict[key] = tensor.to(device)
for key, tensor in sorted_eval_embeddings_dict.items():
    sorted_eval_embeddings_dict[key] = tensor.to(device)

wandb.init()
eval_loss = evaluate_model(custom_model, synthetic_embeddings, sorted_eval_embeddings_dict, NUM_DISTILLED_DATA)
print(f"Initial: Eval Loss {eval_loss}")
# Step 1: Calculate train logits and store them in a dictionary
train_logits_dict = {}  # Dictionary to store train logits
with torch.no_grad():  # the logits serve only as fixed targets, so no graph is needed
    for idx, train_quote in enumerate(train_quotes):
        # Get the embedding corresponding to the current train quote
        train_embedding = train_embeddings[idx]
        # Calculate logits from the embedding
        train_logits = model(inputs_embeds=train_embedding).logits
        # Store the logits in the dictionary with the quote's index as the key
        train_logits_dict[idx] = train_logits
def find_nearest_embeddings_from_pairs(train_embeddings, train_logits_values, synthetic_logits):
    # Create a dictionary mapping from train logits to train embeddings
    # (each element of train_logits_values is assumed to be a tensor, or convertible to one)
    logits_pairs_to_embeddings = dict(zip(train_logits_values, train_embeddings))
    nearest_embeddings = []
    for synthetic_logit in synthetic_logits:
        nearest_embedding = None
        #min_distance = float('inf')
        min_distance = torch.tensor(float('inf'), device=device)
        synthetic_logit = synthetic_logit.to(device)
        for logits, embeddings in logits_pairs_to_embeddings.items():
            # Convert the train logits to a tensor if necessary
            train_logits_tensor = torch.tensor(logits) if not isinstance(logits, torch.Tensor) else logits
            train_logits_tensor = train_logits_tensor.to(device)
            # Calculate the Euclidean distance between the logit tensors
            distance = torch.dist(synthetic_logit, train_logits_tensor, p=2)
            if distance < min_distance:
                min_distance = distance
                nearest_embedding = embeddings
        nearest_embeddings.append(nearest_embedding)
    return nearest_embeddings
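# This helper matches each synthetic logit tensor to the training example whose logits are
# closest in Euclidean distance, which is how the synthetic points are later mapped back
# to human-readable quotes.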
embedding_text_dict = dict(zip(train_embeddings, train_quotes))

# Distillation process
for epoch in range(NUM_EPOCHS):
    shuffled_indices = random.sample(range(len(sorted_train_embeddings_dict)), len(sorted_train_embeddings_dict))
    #total_loss = 0.0
    num_batches = len(sorted_train_embeddings_dict) // NUM_DISTILLED_DATA
    nearest_embeddings = find_nearest_embeddings(synthetic_embeddings, sorted_train_embeddings_dict)
    # Step 1: Create a dictionary mapping index to embeddings
    # Move nearest embeddings to the same device as the model
    #nearest_embeddings = nearest_embeddings.to(device)
    # Pass nearest embeddings through custom_model to get logits?
    # Why would I want the logits of the nearest embeddings of synthetic data? To walk back from logits to embeddings?
    # Whether I index from embeddings -> logits or logits -> embeddings, the similarity score would be the same?
    if False:
        nearest_logits = []
        for tensor in nearest_embeddings:
            output = custom_model(tensor)
            nearest_logits.append(output)
        #e_l_dict = dict(zip(nearest_embeddings, nearest_logits))
        # Generate labels from logits
        # Assuming nearest_logits is a list of torch.Tensor objects
        #nearest_labels = [get_labels_from_logits(logits) for logits in nearest_logits]
    # Move synthetic_embeddings to the same device as the model (already there, so this is a no-op)
    synthetic_embeddings = synthetic_embeddings.to(device)
    synthetic_logits = custom_model(synthetic_embeddings)
    nearest_embeddings_logits = find_nearest_embeddings_from_pairs(train_embeddings, train_logits_dict.values(), synthetic_logits)
    synthetic_logits_embeddings_text = [embedding_text_dict[e] for e in nearest_embeddings_logits]
    # Step 3: Calculate cosine similarity and find the nearest train quote for each synthetic logit
    nearest_train_quotes = []
    for synthetic_logit in synthetic_logits:
        similarities = {}
        # Calculate cosine similarity with each set of train logits
        for idx, train_logit in train_logits_dict.items():
            similarity = cosine_similarity(synthetic_logit.detach().cpu().numpy().reshape(1, -1),
                                           train_logit.detach().cpu().numpy().reshape(1, -1))[0][0]
            similarities[idx] = similarity
        # Find the index of the nearest train quote (highest similarity)
        nearest_idx = max(similarities, key=similarities.get)
        # Retrieve the actual nearest train quote
        nearest_train_quote = train_quotes[nearest_idx]
        nearest_train_quotes.append(nearest_train_quote)
    optimizer.zero_grad()
    nearest_embeddings_tensor = find_nearest_embeddings(synthetic_embeddings, sorted_train_embeddings_dict)
    # Generate labels from logits
    nearest_logit_labels = [generate_tokens_from_embeddings(l) for l in nearest_embeddings_tensor]
    nearest_input_strings = []
    for e in nearest_embeddings_tensor:
        position = dict_embeddings_pos[e]
        nearest_input_strings.append(dict_pos_quotes[position])
    #print(nearest_text)
    print(synthetic_logits_embeddings_text)
    print(nearest_train_quotes)
    print(nearest_input_strings)
    print(nearest_logit_labels)
    batch_indices = []
    batch_train_embeddings_list = []  # Store batch embeddings
    for batch_idx in range(num_batches):
        sorted_batch_indices = np.unique(shuffled_indices[batch_idx * NUM_DISTILLED_DATA:(batch_idx + 1) * NUM_DISTILLED_DATA])
        batch_indices.append(sorted_batch_indices)
        sorted_batch_embeddings = [sorted_train_embeddings_dict[i] for i in sorted_batch_indices]
        batch_train_embeddings = torch.cat(sorted_batch_embeddings, dim=0)
        batch_train_embeddings = batch_train_embeddings.squeeze(1)
        batch_train_embeddings_list.append(batch_train_embeddings)
    # Concatenate batch embeddings into a single tensor
    all_batch_train_embeddings = torch.cat(batch_train_embeddings_list, dim=0)
    # Pass all batch embeddings through custom_model to get logits
    all_batch_train_logits = custom_model(all_batch_train_embeddings)
    # Concatenate synthetic_logits to itself to match the size of all_batch_train_logits
    synthetic_logits_expanded = torch.cat([synthetic_logits] * num_batches, dim=0)
    # Calculate the KL-divergence loss
    loss = F.kl_div(F.log_softmax(synthetic_logits_expanded, dim=-1), F.softmax(all_batch_train_logits, dim=-1), reduction='batchmean')
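    # The synthetic logits are tiled num_batches times so every shuffled batch of real
    # training logits is compared against the same synthetic set; the KL divergence pulls
    # the synthetic points' output distributions toward those of the real data, and since
    # only synthetic_embeddings was given to the optimizer, it is the only tensor updated.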
    # Calculate the average loss for the epoch
    #avg_training_loss = loss / num_batches
    #total_loss += loss.item()  # Accumulate the loss
    # Perform backward and optimizer step after processing all batches
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    # Evaluate at the end of the epoch
    eval_loss = evaluate_model(custom_model, synthetic_embeddings, sorted_eval_embeddings_dict, NUM_DISTILLED_DATA)
    # Create a formatted string that includes epoch, learning rate, training loss, and evaluation loss
    log_string = f"Epoch {epoch}: LR {current_lr:.10f}, Train Loss {loss.item():.3f}, Eval Loss {eval_loss:.3f}"
    # Print the log string
    print(log_string)
    # Log the log string to wandb
    wandb.log({"epoch_log": log_string})

# Finish the wandb session
wandb.finish()
# Save the synthetic dataset as text
# synthetic_logits_embeddings_text, nearest_train_quotes, and nearest_input_strings are lists
# of quote strings, so they are written directly; nearest_logit_labels holds decoded token
# batches, so its entries are joined into strings first.
pd.DataFrame(synthetic_logits_embeddings_text).to_csv('synthetic_logits_embeddings_text.csv')
pd.DataFrame(nearest_train_quotes).to_csv('nearest_train_quotes.csv')
pd.DataFrame(nearest_input_strings).to_csv('nearest_input_strings.csv')
pd.DataFrame([' '.join(t) for t in nearest_logit_labels]).to_csv('nearest_logit_labels.csv')
# This finds the closest embeddings to logits using Euclidean distance between paired vectors of different size.
#pd.DataFrame([' '.join(t) for t in synthetic_logit_labels]).to_csv('synthetic_logit_labels.csv')
#pd.DataFrame([' '.join(t) for t in nearest_logit_labels]).to_csv('nearest_logit_labels.csv')
#pd.DataFrame([' '.join(t) for t in nearest_text]).to_csv('nearest_text.csv')