Created October 8, 2023 04:25
-
-
Save JacksonCakes/54e17dd4bf08a14987a12dc62fc968de to your computer and use it in GitHub Desktop.
Simple visualization of the shifting pattern used in LongLoRA
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simple visualization of the head-shifting pattern used by LongLoRA.
#
# The last ``num_heads // 2`` columns (heads) of a (sequence, head) tensor are
# rolled along the sequence axis by half a group. As a result, their attention
# groups straddle the group boundaries of the unshifted heads. Random values
# stand in for activations; the plot shows one shifted head before and after
# the roll.
import numpy as np

seq_len = 8192
num_heads = 8
group_size = 2048
# (-2048) // 2 == -1024. np.roll with a negative shift moves values toward index 0.
shift_amount = -group_size // 2
# Derived from the shift so the marker stays correct if group_size changes.
# - In the transformed head, position `boundary` holds original token
#   `group_size`, i.e. the original group boundary inside the first shifted group.
# - In the original head, it is the token at which the first shifted group begins.
# Note: the roll's wrap-around discontinuity is at seq_len - boundary, not here.
boundary = -shift_amount
# Head 7 is one of the shifted heads (columns num_heads // 2 .. num_heads - 1).
selected_head = 7


def roll_heads(tensor, shift, first_head=4):
    """Return a copy of *tensor* with columns ``first_head:`` rolled along axis 0.

    Args:
        tensor: 2-D array indexed as (sequence position, head).
        shift: Passed to ``np.roll``. Negative values move entries toward
            index 0; entries rolled past either end wrap around to the other.
        first_head: Index of the first column to roll. Earlier columns are
            copied unchanged.

    Returns:
        A new array. *tensor* itself is not modified.
    """
    rolled = tensor.copy()
    rolled[:, first_head:] = np.roll(tensor[:, first_head:], shift, axis=0)
    return rolled


def main():
    """Build a reproducible random tensor, shift its last half of heads, and plot one head."""
    # Lazy import so roll_heads and the constants are usable without matplotlib.
    import matplotlib.pyplot as plt

    # RandomState(0).rand yields the same values as np.random.seed(0) followed by
    # np.random.rand, without mutating NumPy's global random state.
    original_random_tensor = np.random.RandomState(0).rand(seq_len, num_heads)
    transformed_random_tensor = roll_heads(
        original_random_tensor, shift_amount, first_head=num_heads // 2
    )

    panels = (
        ("Original", original_random_tensor[:, selected_head], "blue"),
        ("Transformed", transformed_random_tensor[:, selected_head], "red"),
    )
    fig, axs = plt.subplots(2, 1, figsize=(12, 6), sharex=True)
    for ax, (kind, values, color) in zip(axs, panels):
        ax.plot(values, label=f"Head {selected_head} - {kind}", color=color)
        ax.set_title(f"Values Along Sequence for Head {selected_head} - {kind} Tensor")
        # Draw the marker BEFORE legend(): legend() only collects artists that
        # already exist, so the boundary label was previously never shown.
        ax.axvline(x=boundary, color='yellow', linestyle='--', label=f"Shift Boundary ({boundary})")
        ax.legend()
    axs[-1].set_xlabel("Token Position in Sequence")
    fig.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.