@ShomyLiu
Created December 31, 2018 08:27
Factorization Machine (FM): a PyTorch implementation
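import torch
import torch.nn as nn


class FM_Layer(nn.Module):
    def __init__(self, n=10, k=5):
        """
        n: input dimension
        k: dimension of the latent factors
        """
        super(FM_Layer, self).__init__()
        self.n = n
        self.k = k
        self.linear = nn.Linear(self.n, 1)  # first two terms: bias + linear weights
        self.V = nn.Parameter(torch.randn(self.n, self.k))  # factor (interaction) matrix

    def fm_layer(self, x):
        """
        :param x: a batch of n-dimensional vectors, shape (batch, n)
        :return: one real value per sample, shape (batch, 1)
        """
        linear_part = self.linear(x)
        interaction_part_1 = torch.mm(x, self.V)
        interaction_part_1 = torch.pow(interaction_part_1, 2)  # (xV)^2
        interaction_part_2 = torch.mm(torch.pow(x, 2), torch.pow(self.V, 2))  # (x^2)(V^2)
        # Pairwise term: 0.5 * sum over the k factors of [(xV)^2 - (x^2)(V^2)].
        # Summing over dim=1 only keeps the samples in a batch separate.
        interaction = 0.5 * torch.sum(
            interaction_part_1 - interaction_part_2, dim=1, keepdim=True
        )
        output = linear_part + interaction
        return output

    def forward(self, x):
        return self.fm_layer(x)


fm = FM_Layer(10, 5)
x = torch.randn(1, 10)
output = fm(x)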
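
The pairwise term that fm_layer is meant to compute rests on the standard FM identity: the sum over all pairs i < j of <v_i, v_j> x_i x_j equals 0.5 * sum over factors f of [(sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2]. The sketch below checks that identity numerically in plain Python, without PyTorch, so the arithmetic stays visible. The helper names pairwise_naive and pairwise_fast are hypothetical and not part of the gist.

```python
import random


def pairwise_naive(x, V):
    """Sum of <v_i, v_j> * x_i * x_j over all pairs i < j, in O(n^2 * k)."""
    n, k = len(V), len(V[0])
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            dot = sum(V[i][f] * V[j][f] for f in range(k))
            total += dot * x[i] * x[j]
    return total


def pairwise_fast(x, V):
    """Same quantity via the O(n * k) reformulation used by FM layers."""
    n, k = len(V), len(V[0])
    total = 0.0
    for f in range(k):
        s = sum(V[i][f] * x[i] for i in range(n))            # (xV)_f
        s_sq = sum((V[i][f] * x[i]) ** 2 for i in range(n))  # (x^2 V^2)_f
        total += s * s - s_sq
    return 0.5 * total


# Same sizes as the example above (n=10 features, k=5 factors); values are random.
random.seed(0)
n, k = 10, 5
x = [random.gauss(0, 1) for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(k)] for _ in range(n)]

naive, fast = pairwise_naive(x, V), pairwise_fast(x, V)
print(abs(naive - fast) < 1e-9)
```

Note the order of the subtraction and the placement of the 0.5 factor: the squared sum comes first, the sum of squares is subtracted from it, and the whole difference is halved.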