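"""LanczosNet (Liao et al., "LanczosNet: Multi-Scale Deep Graph Convolutional
Networks", ICLR 2019), adapted here for node classification with a GCN-style
(nfeat, nhid, nclass) interface.
"""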
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Helper from the accompanying lanczos_utils module (assumed to sanity-check
# the diffusion distance lists).
from lanczos_utils import check_dist

EPS = float(np.finfo(np.float32).eps)

__all__ = ['LanczosNet']
class LanczosNet(nn.Module):

  def __init__(self, nfeat, nhid, nclass, dropout=0.5, num_layer=2, num_eig_vec=20,
               spectral_filter_kind='MLP', short_diffusion_dist=[1, 2, 5, 7],
               long_diffusion_dist=[10, 20, 30]):
    super(LanczosNet, self).__init__()
    self.input_dim = nfeat
    self.hidden_dim = nhid
    self.output_dim = nclass
    self.num_layer = num_layer
    self.dropout = dropout
    self.short_diffusion_dist = check_dist(short_diffusion_dist)
    self.long_diffusion_dist = check_dist(long_diffusion_dist)
    self.max_short_diffusion_dist = max(
        self.short_diffusion_dist) if self.short_diffusion_dist else None
    self.max_long_diffusion_dist = max(
        self.long_diffusion_dist) if self.long_diffusion_dist else None
    self.num_scale_short = len(self.short_diffusion_dist)
    self.num_scale_long = len(self.long_diffusion_dist)
    self.num_eig_vec = num_eig_vec
    self.spectral_filter_kind = spectral_filter_kind

    # One linear filter per layer; its input concatenates one message per
    # short-diffusion scale, one per long-diffusion scale, plus the plain
    # one-hop term.
    dim_list = [self.input_dim] + [self.hidden_dim] * (num_layer - 1) + [self.output_dim]
    self.filter = nn.ModuleList([
        nn.Linear(dim_list[tt] * (
            self.num_scale_short + self.num_scale_long + 1),
                  dim_list[tt + 1]) for tt in range(self.num_layer)
    ])

    # Spectral filters: one small MLP per layer that reweights the powered
    # Ritz values across the long-diffusion scales.
    if self.spectral_filter_kind == 'MLP' and self.num_scale_long > 0:
      self.spectral_filter = nn.ModuleList([
          nn.Sequential(*[
              nn.Linear(self.num_scale_long, 128),
              nn.ReLU(),
              nn.Linear(128, self.num_scale_long)
          ]) for _ in range(self.num_layer)
      ])

    self._init_param()
  def _init_param(self):
    for ff in self.filter:
      if isinstance(ff, nn.Linear):
        nn.init.xavier_uniform_(ff.weight.data)
        if ff.bias is not None:
          ff.bias.data.zero_()

    if self.spectral_filter_kind == 'MLP' and self.num_scale_long > 0:
      for ff in self.spectral_filter:
        for f in ff:
          if isinstance(f, nn.Linear):
            nn.init.xavier_uniform_(f.weight.data)
            if f.bias is not None:
              f.bias.data.zero_()
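  # For each long-diffusion scale i, the method below builds the filter
  #   L_i = Q diag(f_i(d)) Q^T,
  # where Q holds the Ritz vectors, d stacks the Ritz values raised to the
  # long diffusion powers, and f_i is the per-layer MLP (applied jointly
  # across scales) or the identity when spectral_filter_kind != 'MLP'.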
  def _get_spectral_filters(self, T_list, Q, layer_idx):
    """ Construct spectral filters based on Lanczos outputs

      Args:
        T_list: list of length num_scale_long, each element of shape B X K
        Q: shape B X N X K
        layer_idx: index of the current propagation layer

      Returns:
        L: shape B X N X N X num_scale_long
    """
    # multi-scale diffusion
    L = []

    if self.spectral_filter_kind == 'MLP':
      # Reweight the powered Ritz values jointly across scales with the MLP.
      DD = torch.stack(
          T_list, dim=2).view(Q.shape[0] * Q.shape[2], -1)  # shape BK X num_scale_long
      DD = self.spectral_filter[layer_idx](DD).view(Q.shape[0], Q.shape[2],
                                                    -1)  # shape B X K X num_scale_long
      for ii in range(self.num_scale_long):
        tmp_DD = DD[:, :, ii].unsqueeze(1).repeat(1, Q.shape[1],
                                                  1)  # shape B X N X K
        L += [(Q * tmp_DD).bmm(Q.transpose(1, 2))]
    else:
      # No learned filter: use the powered Ritz values directly.
      for ii in range(self.num_scale_long):
        DD = T_list[ii].unsqueeze(1).repeat(1, Q.shape[1], 1)  # shape B X N X K
        L += [(Q * DD).bmm(Q.transpose(1, 2))]

    return torch.stack(L, dim=3)
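  # The forward pass propagates num_layer times. Each layer's message
  # concatenates (i) short-diffusion terms L^i x for i in short_diffusion_dist,
  # (ii) long-diffusion terms built from the spectral filters above, and
  # (iii) the one-hop term L x, then applies the layer's linear filter,
  # ReLU, and dropout.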
  def forward(self, node_feat, L, D, V):
    """
      shape parameters:
        batch size = B
        embedding dim = D
        max number of nodes within one mini batch = N
        number of approximated eigenvalues, i.e., Ritz values = K

      Args:
        node_feat: float tensor, shape B X N X D
        L: float tensor, graph Laplacian, shape B X N X N
        D: float tensor, Ritz values, shape B X K
        V: float tensor, Ritz vectors, shape B X N X K
    """
    # Allow unbatched input: add a batch dimension and remove it at the end.
    if node_feat.dim() == 2:
      squeeze = True
      node_feat = node_feat.unsqueeze(0)
      L = L.unsqueeze(0)
      D = D.unsqueeze(0)
      V = V.unsqueeze(0)
    else:
      squeeze = False

    batch_size = node_feat.shape[0]
    num_node = node_feat.shape[1]

    # Powers of the Ritz values, one per long diffusion distance.
    D_pow_list = []
    for ii in self.long_diffusion_dist:
      D_pow_list += [torch.pow(D, ii)]  # shape B X K

    ###########################################################################
    # Graph Convolution
    ###########################################################################
    state = node_feat

    # propagation
    for tt in range(self.num_layer):
      msg = []

      if self.num_scale_long > 0:
        Lf = self._get_spectral_filters(D_pow_list, V, tt)

      # short diffusion
      if self.num_scale_short > 0:
        tmp_state = state
        for ii in range(1, self.max_short_diffusion_dist + 1):
          tmp_state = torch.bmm(L, tmp_state)
          if ii in self.short_diffusion_dist:
            msg += [tmp_state]

      # long diffusion
      if self.num_scale_long > 0:
        for ii in range(self.num_scale_long):
          msg += [torch.bmm(Lf[:, :, :, ii], state)]  # shape: B X N X D

      msg += [torch.bmm(L, state)]  # shape: B X N X D
      msg = torch.cat(msg, dim=2).view(num_node * batch_size, -1)
      state = F.relu(self.filter[tt](msg)).view(batch_size, num_node, -1)
      state = F.dropout(state, self.dropout, training=self.training)

    if squeeze:
      state = state.squeeze(0)

    return F.softmax(state, dim=-1)
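

# ---------------------------------------------------------------------------
# Minimal usage sketch (an addition, not part of the original model code):
# exercises forward() with random placeholder inputs of the documented shapes.
# In practice L would be a normalized graph Laplacian/affinity matrix and
# (D, V) the Ritz values/vectors from a Lanczos decomposition of it; random
# tensors are used here only to illustrate the interface.
if __name__ == '__main__':
  B, N, K = 2, 16, 20
  nfeat, nhid, nclass = 8, 32, 4
  model = LanczosNet(nfeat, nhid, nclass, num_eig_vec=K)
  node_feat = torch.randn(B, N, nfeat)
  L = torch.randn(B, N, N)  # placeholder for the graph Laplacian
  D = torch.rand(B, K)      # placeholder Ritz values
  V = torch.randn(B, N, K)  # placeholder Ritz vectors
  out = model(node_feat, L, D, V)
  print(out.shape)          # torch.Size([2, 16, 4])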