@Chillee
Last active May 14, 2023 09:45
lora_example.py
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.utils._pytree import tree_map


class LoraTensor(object):
    """Wraps a frozen base weight plus low-rank LoRA factors A and B."""

    def __init__(self, weights, A, B):
        self.weights = weights
        self.A = A
        self.B = B

    def __repr__(self):
        return f"LoraTensor(weight={self.weights}, A={self.A}, B={self.B})"

    def tensor(self):
        # Materialize the full weight: W + A @ B
        return self.weights + self.A @ self.B

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def unwrap(e):
            return e.tensor() if isinstance(e, LoraTensor) else e

        if func == torch.nn.functional.linear and isinstance(args[1], LoraTensor):
            # Fast path for F.linear: keep the update factored instead of
            # materializing A @ B. F.linear computes x @ W.T, so the LoRA
            # contribution is x @ B.T @ A.T.
            orig_weight, A, B = args[1].weights, args[1].A, args[1].B
            lora_part = (args[0] @ B.t()) @ A.t()
            return lora_part + func(args[0], orig_weight, args[2])
        else:
            # For any other op, fall back to the materialized weight.
            args, kwargs = tree_map(unwrap, (args, kwargs))
            return func(*args, **kwargs)


class LoraParametrization(nn.Module):
    def __init__(self, A, B):
        super().__init__()
        self.A = torch.nn.Parameter(A)
        self.B = torch.nn.Parameter(B)

    def forward(self, W):
        return LoraTensor(W, self.A, self.B)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # bias is False just for simplicity
        self.layer = torch.nn.Linear(8, 8, bias=False)

    def forward(self, x):
        return self.layer(x).relu()


inp = torch.randn(8, 8)
model = Model()
model.layer.weight.data.zero_()
out = model(inp)

# Freeze the base weight; only the LoRA factors A and B are trained.
model.layer.weight.requires_grad_(False)
parametrize.register_parametrization(
    model.layer,
    "weight",
    LoraParametrization(
        torch.ones(model.layer.weight.shape[0], 1),
        torch.ones(1, model.layer.weight.shape[1]),
    ),
    unsafe=True,
)
optim = torch.optim.SGD([param for param in model.parameters() if param.requires_grad], lr=0.1)
out = model(torch.randn(8, 8))
out.sum().backward()
optim.step()
print([(key, param.grad) for key, param in model.named_parameters() if param.requires_grad])
print([(key, param) for key, param in model.named_parameters()])
Chillee commented May 3, 2023

[image attachment]

SushantDaga commented May 14, 2023

Hi Horace, I use a similar tensor-based implementation during inference to serve different weight adapters in a single batch. It's very useful when serving personalized adapters for each customer on a common base model.

Do you think there's a benefit (mainly in terms of cost or speed) to using a tensor implementation during training as well?
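
For reference, a minimal sketch of the kind of per-request adapter batching described above: each row of the batch gets its own low-rank A_i @ B_i update on top of a shared base weight. The helper name and shapes here are illustrative assumptions, not something taken from the gist.

import torch
import torch.nn.functional as F

# Hypothetical helper, not part of the gist: one base weight shared by all
# requests, plus a separate low-rank adapter (A_i, B_i) per row of the batch.
def batched_lora_linear(x, base_weight, A, B, bias=None):
    # x:           (batch, in_features)
    # base_weight: (out_features, in_features), shared across the batch
    # A:           (batch, out_features, r)
    # B:           (batch, r, in_features)
    base = F.linear(x, base_weight, bias)        # shared base-model path
    tmp = torch.bmm(B, x.unsqueeze(-1))          # (batch, r, 1)
    lora = torch.bmm(A, tmp).squeeze(-1)         # (batch, out_features)
    return base + lora

# Example: 4 requests, each with its own rank-2 adapter over a shared 8x8 weight.
x = torch.randn(4, 8)
W = torch.randn(8, 8)
A = torch.randn(4, 8, 2)
B = torch.randn(4, 2, 8)
out = batched_lora_linear(x, W, A, B)            # (4, 8)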
