Skip to content

Instantly share code, notes, and snippets.

@dfaker
Created September 28, 2022 22:21
Show Gist options
  • Save dfaker/664a9cd8d2cc391c42c2a612955a3fbf to your computer and use it in GitHub Desktop.
Save dfaker/664a9cd8d2cc391c42c2a612955a3fbf to your computer and use it in GitHub Desktop.
import math
from pathlib import Path

import numpy as np
import torch
def lerp(theta0, theta1, alpha):
    """Linearly interpolate between ``theta0`` and ``theta1``.

    Returns ``theta0`` when ``alpha == 0`` and ``theta1`` when ``alpha == 1``.
    Works on anything that supports scalar multiplication and addition
    (Python numbers, NumPy arrays, torch tensors).
    """
    return (1 - alpha) * theta0 + alpha * theta1


def slerp(theta0, theta1, alpha, dot_threshold=0.9995):
    """Spherically interpolate between two tensors of the same shape.

    Both tensors are treated as flat vectors. The angle between them is
    measured on their normalized directions, and the result is the weighted
    sum ``s0 * theta0 + s1 * theta1`` using the standard slerp weights, so the
    original magnitudes are blended rather than discarded.

    Falls back to :func:`lerp` when either tensor is all zeros (its direction
    is undefined, and normalizing it would produce NaN) or when the two
    directions are nearly (anti-)parallel, where ``sin(angle)`` approaches
    zero.

    Args:
        theta0: Tensor returned when ``alpha == 0``.
        theta1: Tensor returned when ``alpha == 1``; must have the same shape
            as ``theta0``.
        alpha: Interpolation weight of ``theta1`` (typically in [0, 1]).
        dot_threshold: Absolute cosine similarity at or above which linear
            interpolation is used instead of slerp.

    Returns:
        A tensor shaped like ``theta0``. The weights are applied as Python
        floats, so floating-point inputs keep their dtype.

    Raises:
        RuntimeError: if the tensors have different numbers of elements.
    """
    # Measure the geometry in float32 on flat views. reshape()/float() do not
    # copy when the input is already a contiguous float32 tensor, and float32
    # avoids the precision/overflow problems of reducing in float16.
    flat0 = theta0.reshape(-1).float()
    flat1 = theta1.reshape(-1).float()
    norm_product = (flat0.norm() * flat1.norm()).item()
    if norm_product == 0.0:
        # A zero tensor has no direction; dividing by its norm would give NaN.
        return lerp(theta0, theta1, alpha)

    # Cosine of the angle between the two directions.
    dot = torch.dot(flat0, flat1).item() / norm_product
    if abs(dot) >= dot_threshold:
        # Nearly collinear: sin(angle) ~ 0 would make the weights unstable.
        return lerp(theta0, theta1, alpha)

    omega = math.acos(dot)
    sin_omega = math.sin(omega)
    s0 = math.sin((1.0 - alpha) * omega) / sin_omega
    s1 = math.sin(alpha * omega) / sin_omega
    return s0 * theta0 + s1 * theta1
# Blend the 'state_dict' of a secondary checkpoint into a primary one with
# slerp, then save the primary checkpoint under a new name.
# pathlib keeps the paths portable (a literal backslash only acts as a
# separator on Windows).
MODEL_DIR = Path("models")
PRIMARY_PATH = MODEL_DIR / "model-aa-base.ckpt"
SECONDARY_PATH = MODEL_DIR / "model-aa-waifu.ckpt"
OUTPUT_PATH = MODEL_DIR / "model-aa-base-plus-waifu.ckpt"
# Passed as slerp's alpha: 0.0 keeps the primary weights, 1.0 takes the
# secondary's. NOTE(review): 1.0 - 0.25 = 0.75 leans toward the secondary
# model; confirm this is the intended ratio.
SECONDARY_WEIGHT = 1.0 - 0.25

# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# checkpoints from trusted sources.
primary_model = torch.load(PRIMARY_PATH, map_location="cpu")
secondary_model = torch.load(SECONDARY_PATH, map_location="cpu")
theta_0 = primary_model["state_dict"]
theta_1 = secondary_model["state_dict"]

# The union is materialized as a new set before the loop, so deleting from
# theta_1 below is safe; sorting makes the processing/print order repeatable.
for key in sorted(theta_0.keys() | theta_1.keys()):
    if "model" not in key:
        continue
    if key in theta_0 and key in theta_1:
        print(key)
        theta_0[key] = slerp(theta_0[key], theta_1[key], SECONDARY_WEIGHT)
    elif key in theta_1:
        # Present only in the secondary checkpoint: copy it over unchanged.
        theta_0[key] = theta_1[key]
        del theta_1[key]

# Drop every reference to the secondary weights before saving.
del secondary_model, theta_1
torch.save(primary_model, OUTPUT_PATH)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment