Skip to content

Instantly share code, notes, and snippets.

@khanhnamle1994
Created February 22, 2020 22:06
Show Gist options
  • Save khanhnamle1994/e3dbf8e79bd7a563b0544ce54231312f to your computer and use it in GitHub Desktop.
Save khanhnamle1994/e3dbf8e79bd7a563b0544ce54231312f to your computer and use it in GitHub Desktop.
Factorization Machines class
import torch
from torch import nn
import torch.nn.functional as F
class MF(nn.Module):
    """Factorization-machine rating model: per-feature biases + pairwise interactions.

    The pieces this class relies on are not defined in this class body: the
    attributes ``feat`` and ``bias_feat`` (read through ``.weight``) and
    ``c_feat``, plus the module-level helpers ``index_into``,
    ``factorization_machine`` and ``l2_regularize``. They are presumably set
    up elsewhere (e.g. an ``__init__`` and helper functions not shown here)
    — verify against the full module.
    """

    @staticmethod
    def _sum_per_example(values):
        """Sum ``values`` over every dimension except the leading (batch) one.

        Replaces ``values.squeeze().sum(dim=1)``: a bare ``squeeze()`` also
        removes a size-1 batch dimension (and any other size-1 dimension), so
        a batch of one example either raised in ``sum(dim=1)`` or was reduced
        over the wrong axis.
        """
        return values.flatten(start_dim=1).sum(dim=1)

    def forward(self, train_x):
        """Predict one score per example in ``train_x``.

        Implemented as ``forward`` rather than overriding ``__call__`` so that
        ``model(train_x)`` is dispatched by ``nn.Module.__call__`` and any
        registered hooks still run.

        Args:
            train_x: feature-index tensor; assumed (batch, n_features) with
                dim 0 as the batch — TODO confirm against the caller.

        Returns:
            Tensor of biases + interactions with one entry per batch row.
        """
        # Look up each feature's bias and add them up per example.
        biases = self._sum_per_example(index_into(self.bias_feat.weight, train_x))
        # Look up each feature's vector from the feature weight matrix.
        vector_features = index_into(self.feat.weight, train_x)
        # Pairwise feature interactions via the factorization-machine helper.
        interactions = self._sum_per_example(factorization_machine(vector_features))
        # Final prediction is the sum of biases and interactions.
        return biases + interactions

    def loss(self, prediction, target):
        """Return the MSE between ``prediction`` and ``target`` plus an L2 penalty.

        The penalty is ``l2_regularize(self.feat.weight)`` scaled by
        ``self.c_feat``. Both inputs are squeezed before comparison so that
        size-1 dimensions (e.g. a trailing column) do not cause a shape
        mismatch.
        """
        # Mean squared error between squeezed target and prediction.
        loss_mse = F.mse_loss(prediction.squeeze(), target.squeeze())
        # Regularization over the feature matrix, weighted by c_feat.
        prior_feat = l2_regularize(self.feat.weight) * self.c_feat
        return loss_mse + prior_feat
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment