Skip to content

Instantly share code, notes, and snippets.

@lzqlzzq
Last active November 5, 2023 12:54
Show Gist options
  • Save lzqlzzq/3735f367a969fdb6302f33af3edacd18 to your computer and use it in GitHub Desktop.
Save lzqlzzq/3735f367a969fdb6302f33af3edacd18 to your computer and use it in GitHub Desktop.
Field-matrixed Factorization Machines (FmFM) implementation for PyTorch
# https://arxiv.org/pdf/2102.12994.pdf
from typing import Dict
from itertools import permutations
import torch
from torch import nn
from math import sqrt
class FMFM(nn.Module):
    """Field-matrixed factorization-machine interaction layer.

    For every ordered pair of distinct fields ``(i, j)`` a bias-free linear
    map ``M_ij`` (``dim_i -> dim_j``) transforms field ``i``'s embedding; the
    result is multiplied element-wise with field ``j``'s embedding. All pair
    products are concatenated along the last axis and fed through a final
    linear projection to ``output_dim``.

    Args:
        embedding_dims: mapping from field name to that field's embedding
            size. At least two fields are required, and the generated pair
            keys ``"<field_i>_<field_j>"`` must be unique.
        output_dim: size of the projection output.
        norm_dim: if True, each field embedding is divided by
            ``sqrt(embedding_size)`` before the interactions are computed.

    Raises:
        ValueError: fewer than two fields, or two field pairs map to the same
            ``"<field_i>_<field_j>"`` key (e.g. fields ``"a_b"``/``"c"`` and
            ``"a"``/``"b_c"``), which would silently drop one pair's weights.
    """

    def __init__(self,
                 embedding_dims: Dict[str, int],
                 output_dim: int,
                 norm_dim: bool = True):
        super().__init__()
        num_fields = len(embedding_dims)
        if num_fields < 2:
            raise ValueError(
                f"FMFM needs at least two fields, got {num_fields}")

        linears = {
            f"{name_i}_{name_j}": nn.Linear(dim_i, dim_j, bias=False)
            for (name_i, dim_i), (name_j, dim_j)
            in permutations(embedding_dims.items(), 2)
        }
        # Keys are built by joining names with '_', so distinct pairs can
        # collide; a dict comprehension would silently keep only the last one.
        if len(linears) != num_fields * (num_fields - 1):
            raise ValueError(
                "field names produce colliding pair keys when joined with "
                "'_'; rename the fields so every '<field_i>_<field_j>' is "
                "unique")
        self.field_linears = nn.ModuleDict(linears)

        # Each field j appears as the right-hand side of (num_fields - 1)
        # pairs, so the concatenated interaction width is sum(dims) * (n - 1).
        self.projection = nn.Linear(
            sum(embedding_dims.values()) * (num_fields - 1), output_dim)
        self.norm_dim = norm_dim

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute the projected pairwise field interactions.

        Args:
            x: mapping from field name to embedding tensor whose last axis is
                that field's embedding size. All tensors must share the same
                leading dimensions (required by the last-axis concatenation).
                The caller's dict and tensors are not modified.

        Returns:
            ``self.projection`` applied to the concatenated pair products.
        """
        if self.norm_dim:
            # Out-of-place on purpose: dividing in place mutated the caller's
            # tensors (compounding on repeated calls) and fails under autograd
            # for leaf tensors that require grad.
            x = {name: emb / sqrt(emb.shape[-1]) for name, emb in x.items()}

        # Sorting fixes the concatenation layout independent of the caller's
        # dict insertion order, so the projection weights always line up.
        fields = sorted(x.items())
        transformed = []
        partners = []
        for (name_i, emb_i), (name_j, emb_j) in permutations(fields, 2):
            transformed.append(self.field_linears[f"{name_i}_{name_j}"](emb_i))
            partners.append(emb_j)

        return self.projection(
            torch.cat(transformed, dim=-1) * torch.cat(partners, dim=-1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment