This is probably a clearer PyTorch implementation of [chihkuanyeh/Representer_Point_Selection](https://github.com/chihkuanyeh/Representer_Point_Selection/blob/master/compute_representer_vals.py). It fine-tunes the final linear layer under an L2 penalty and returns the representer values `alphas` for each training sample. Some details differ from the original and are marked as changeable in the comments. A usage sketch follows the code below.
import copy  # used to snapshot the best weights

import torch
import torch.nn as nn


class Classifier(nn.Module):
    def __init__(self, pretrained_linear: nn.Linear):
        super().__init__()
        assert pretrained_linear.bias is not None  # changeable
        self.linear = nn.Linear(
            in_features=pretrained_linear.in_features,
            out_features=pretrained_linear.out_features,
            bias=True,
        )
        self.linear.weight.data = pretrained_linear.weight.data.clone()
        self.linear.bias.data = pretrained_linear.bias.data.clone()

    def forward(self, x):
        return self.linear(x)


def calculate_alphas(
    classifier: Classifier, features, target_probs,
    learning_rate=1, lambda_=0.003, num_epochs=40000,
    device='cpu',
):
    """
    features (N, m)
    target_probs (N, num_classes)
    alphas (N, num_classes)
    """
    features = torch.Tensor(features).to(device)
    target_probs = torch.Tensor(target_probs).to(device)
    classifier = classifier.to(device)

    # loss_fn = nn.CrossEntropyLoss()  # changeable
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.SGD(classifier.parameters(), lr=learning_rate)

    min_loss = float('inf')
    min_grad = float('inf')
    patience = 3000
    steps_without_improvement = 0
    best_weights = None

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        # L2 penalty over weight and bias; use the parameters themselves
        # (not .data) so the penalty stays in the autograd graph and is
        # minimized together with the data loss.
        l2_norm = torch.sum(
            torch.square(
                torch.cat(
                    [
                        classifier.linear.weight,
                        classifier.linear.bias.unsqueeze(dim=1),
                    ],
                    dim=1,
                )
            )
        )  # changeable, bias included
        logits = classifier(features)
        loss = loss_fn(logits, target_probs) + lambda_ * l2_norm
        loss.backward()
        optimizer.step()

        grad = torch.cat(
            [
                classifier.linear.weight.grad,
                classifier.linear.bias.grad.unsqueeze(dim=1),
            ],
            dim=1,
        )
        # grad_norm = torch.norm(grad).item()
        grad_norm = torch.mean(torch.abs(grad)).item()
        if grad_norm < min_grad:
            min_grad = grad_norm
            # Deep-copy the snapshot; state_dict() alone returns references
            # that later optimizer steps would overwrite.
            best_weights = copy.deepcopy(classifier.state_dict())

        # TODO: stop criterion
        if loss.item() < min_loss:
            min_loss = loss.item()
            steps_without_improvement = 0
        else:
            steps_without_improvement += 1
        if (steps_without_improvement >= patience) and (min_grad < 1e-6):
            break

        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}, Min Grad: {min_grad}')

    classifier.load_state_dict(best_weights)
    with torch.no_grad():
        logits = classifier(features)
        # changeable, different derivative for different loss_fn
        # pred_probs = F.softmax(logits, dim=1)  # needs `import torch.nn.functional as F`
        pred_probs = torch.sigmoid(logits)
    # Representer values: alpha_i = -1/(2 * lambda * N) * d loss_i / d logits_i
    derivative = pred_probs - target_probs
    num_samples = len(features)
    alphas = -derivative / (2.0 * lambda_ * num_samples)
    return alphas
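
Usage sketch. This is a minimal example of how the two pieces above fit together, not part of the gist itself: `model.fc`, `features`, and `train_probs` are placeholder names, where `features` would be the penultimate-layer activations of the N training samples and `train_probs` the pretrained model's predicted probabilities for them (the original implementation fine-tunes against the model's own soft predictions).

# Hypothetical usage; `model.fc`, `features`, `train_probs` are assumptions.
classifier = Classifier(model.fc)  # copy the pretrained final nn.Linear
alphas = calculate_alphas(
    classifier, features, train_probs,
    learning_rate=1, lambda_=0.003, num_epochs=40000,
)
# alphas has shape (N, num_classes); alphas[i, j] is the representer value
# of training sample i for class j. Per the representer point selection
# paper, large positive values indicate excitatory training points and
# large negative values inhibitory ones.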