Skip to content

Instantly share code, notes, and snippets.

@BarclayII
Last active December 9, 2023 05:47
Show Gist options
  • Star 33 You must be signed in to star a gist
  • Fork 8 You must be signed in to fork a gist
  • Save BarclayII/ba88c3101a1ab784c5f27f73e2088ec1 to your computer and use it in GitHub Desktop.
Save BarclayII/ba88c3101a1ab784c5f27f73e2088ec1 to your computer and use it in GitHub Desktop.
PinSage example implementation
import pandas as pd
import dgl
import os
import torch
class MovieLens(object):
def __init__(self, directory):
'''
directory: path to movielens directory which should have the three
files:
users.dat
movies.dat
ratings.dat
'''
self.directory = directory
users = []
movies = []
ratings = []
# read users
with open(os.path.join(directory, 'users.dat')) as f:
for l in f:
id_, gender, age, occupation, zip_ = l.split('::')
users.append({
'id': int(id_),
'gender': gender,
'age': age,
'occupation': occupation,
'zip': zip_,
})
self.users = pd.DataFrame(users).set_index('id')
# read movies
with open(os.path.join(directory, 'movies.dat'), encoding='latin1') as f:
for l in f:
id_, title, genres = l.split('::')
genres_set = set(genres.split('|'))
data = {'id': int(id_), 'title': title}
for g in genres_set:
data[g] = True
movies.append(data)
self.movies = pd.DataFrame(movies).set_index('id')
# read ratings
with open(os.path.join(directory, 'ratings.dat')) as f:
for l in f:
user_id, movie_id, rating, timestamp = [int(_) for _ in l.split('::')]
ratings.append({
'user_id': user_id,
'movie_id': movie_id,
'rating': rating,
'timestamp': timestamp,
})
self.ratings = pd.DataFrame(ratings)
# randomly generate training-validation-test set on the ratings table
test_set = self.ratings.sample(frac=0.05, random_state=1).index
valid_set = self.ratings.sample(frac=0.05, random_state=2).index
valid_set = valid_set.difference(test_set)
self.ratings['valid'] = self.ratings.index.isin(valid_set)
self.ratings['test'] = self.ratings.index.isin(test_set)
def todglgraph(self):
'''
returns:
g, user_ids, movie_ids:
The DGL graph itself. Each edge has a binary feature "valid" and a binary
feature "test" indicating validation/test example.
The list of user IDs (node i corresponds to user user_ids[i])
The list of movie IDs (node i + len(user_ids) corresponds to movie movie_ids[i])
'''
user_ids = list(self.users.index)
movie_ids = list(self.movies.index)
user_ids_invmap = {id_: i for i, id_ in enumerate(user_ids)}
movie_ids_invmap = {id_: i for i, id_ in enumerate(movie_ids)}
g = dgl.DGLGraph()
g.add_nodes(len(user_ids) + len(movie_ids))
rating_user_vertices = [user_ids_invmap[id_] for id_ in self.ratings['user_id'].values]
rating_movie_vertices = [movie_ids_invmap[id_] + len(user_ids)
for id_ in self.ratings['movie_id'].values]
valid_tensor = torch.from_numpy(self.ratings['valid'].values.astype('uint8'))
test_tensor = torch.from_numpy(self.ratings['test'].values.astype('uint8'))
g.add_edges(rating_user_vertices,
rating_movie_vertices,
data={'valid': valid_tensor, 'test': test_tensor})
g.add_edges(rating_movie_vertices,
rating_user_vertices,
data={'valid': valid_tensor, 'test': test_tensor})
return g, user_ids, movie_ids
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
def get_embeddings(h, nodeset):
return h[nodeset]
def put_embeddings(h, nodeset, new_embeddings):
n_nodes = nodeset.shape[0]
n_features = h.shape[1]
return h.scatter(0, nodeset[:, None].expand(n_nodes, n_features), new_embeddings)
def random_walk_sampler(G, nodeset, n_traces, n_hops):
'''
G: DGLGraph
nodeset: 1D CPU Tensor of node IDs
n_traces: int
n_hops: int
return: 3D CPU Tensor or node IDs (n_nodes, n_traces, n_hops + 1)
'''
n_nodes = nodeset.shape[0]
traces = torch.zeros(n_nodes, n_traces, n_hops + 1, dtype=torch.int64)
for i in range(n_nodes):
for j in range(n_traces):
cur = nodeset[i]
for k in range(n_hops + 1):
traces[i, j, k] = cur
neighbors = G.successors(cur)
assert neighbors.shape[0] > 0
cur = neighbors[torch.randint(len(neighbors), ())]
return traces
def random_walk_distribution(G, nodeset, n_traces, n_hops):
n_nodes = nodeset.shape[0]
n_available_nodes = G.number_of_nodes()
traces = random_walk_sampler(G, nodeset, n_traces, n_hops)
visited_nodes = traces[:, :, 1:].view(n_nodes, -1) # (n_nodes, n_visited_other_nodes)
visited_counts = (
torch.zeros(n_nodes, n_available_nodes)
.scatter_add_(1, visited_nodes, torch.ones_like(visited_nodes, dtype=torch.float64)))
visited_prob = visited_counts / visited_counts.sum(1, keepdim=True)
return visited_prob
def random_walk_distribution_topt(G, nodeset, n_traces, n_hops, top_T):
'''
returns the top T important neighbors of each node in nodeset, as well as
the weights of the neighbors.
'''
visited_prob = random_walk_distribution(G, nodeset, n_traces, n_hops)
return visited_prob.topk(1, top_T)
def random_walk_nodeflow(G, nodeset, n_layers, n_traces, n_hops, top_T):
'''
returns a list of triplets (
"active" node IDs whose embeddings are computed at the i-th layer (num_nodes,)
weight of each neighboring node of each "active" node on the i-th layer (num_nodes, top_T)
neighboring node IDs for each "active" node on the i-th layer (num_nodes, top_T)
)
'''
nodeflow = []
cur_nodeset = nodeset
for i in reversed(range(n_layers)):
nb_weights, nb_nodes = random_walk_distribution_topt(G, nodeset, n_traces, n_hops, top_T)
nodeflow.insert((cur_nodeset, nb_weights, nb_nodes))
cur_nodeset = torch.cat([nb_nodes.view(-1), cur_nodeset]).unique()
return nodeflow
class PinSageConv(nn.Module):
def __init__(self, in_features, out_features, hidden_features):
super(PinSageConv, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.hidden_features = hidden_features
self.Q = nn.Linear(in_features, hidden_features)
self.W = nn.Linear(in_features + hidden_features, out_features)
def forward(self, h, nodeset, nb_nodes, nb_weights):
'''
h: node embeddings (num_total_nodes, in_features), or a container
of the node embeddings (for distributed computing)
nodeset: node IDs in this minibatch (num_nodes,)
nb_nodes: neighbor node IDs of each node in nodeset (num_nodes, num_neighbors)
nb_weights: weight of each neighbor node (num_nodes, num_neighbors)
return: new node embeddings (num_nodes, out_features)
'''
n_nodes, T = nb_nodes.shape[0]
h_nodeset = get_embeddings(h, nodeset) # (n_nodes, in_features)
h_neighbors = get_embeddings(h, nb_nodes.view(-1)).view(n_nodes, T, self.in_features)
h_neighbors = F.relu(self.Q(h_neighbors))
h_agg = (nb_weights[:, :, None] * h_neighbors).sum(1) / nb_weights.sum(1, keepdim=True)
h_concat = torch.cat([h_nodeset, h_agg], 1)
h_new = F.relu(self.W(h_concat))
h_new /= h_new.norm(dim=1, keepdim=True)
return h_new
class PinSage(nn.Module):
'''
Completes a multi-layer PinSage convolution
G: DGLGraph
feature_sizes: the dimensionality of input/hidden/output features
T: number of neighbors we pick for each node
n_traces: number of random walk traces to generate during sampling
n_hops: number of hops of each random walk trace during sampling
'''
def __init__(self, G, feature_sizes, T, n_traces, n_hops):
super(PinSage, self).__init__()
self.G = G
self.T = T
self.n_traces = n_traces
self.n_hops = n_hops
self.in_features = feature_sizes[0]
self.out_features = feature_sizes[-1]
self.n_layers = len(feature_sizes) - 1
self.convs = nn.ModuleList()
for i in range(self.n_layers):
self.convs.append(PinSageConv(
feature_sizes[i], feature_sizes[i+1], feature_sizes[i+1]))
def forward(self, h, nodeset):
'''
Given a complete embedding matrix h and a list of node IDs, return
the output embeddings of these node IDs.
h: node embeddings (num_total_nodes, in_features), or a container
of the node embeddings (for distributed computing)
nodeset: node IDs in this minibatch (num_nodes,)
return: new node embeddings (num_nodes, out_features)
'''
nodeflow = random_walk_nodeflow(self.G, nodeset, self.n_layers, self.n_traces, self.n_hops, self.T)
for i, (nodeset, nb_weights, nb_nodes) in enumerate(nodeflow):
new_embeddings = self.convs[i](h, nodeset, nb_nodes, nb_weights)
h = put_embeddings(h, nodeset, new_embeddings)
return get_embeddings(h, nodeset)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment