Skip to content

Instantly share code, notes, and snippets.

@AlexMRuch
Created May 2, 2019 13:06
Show Gist options
  • Save AlexMRuch/a5d4a4edfb849f6348574f792e86d7cd to your computer and use it in GitHub Desktop.
Save AlexMRuch/a5d4a4edfb849f6348574f792e86d7cd to your computer and use it in GitHub Desktop.
UniformRandomMetaPathWalk_def.py
# Author: Alexander Ruch
# Modified from github.com/stellargraph/stellargraph/blob/develop/stellargraph/data/explorer.py
# Original stellargraph code written for networkx
# This script is substantially faster and parallelized
# Note: This script is still not as fast as original metapath2vec authors' dictionary-based approach
# Run with OPENBLAS_MAIN_FREE=1 python -i UniformRandomMetaPathWalk_def.py
"""
**** Run with OPENBLAS_MAIN_FREE=1 python -i UniformRandomMetaPathWalk_def.py ****
Performs metapath-driven uniform random walks on heterogeneous graphs.
Args:
graph: <graph> an undirected multirelational graph-tool graph object
n: <int> Total number of random walks per root node
length: <int> Maximum length of each random walk
metapaths: <list> List of lists of node labels that specify a metapath schema, e.g.,
[['Author', 'Paper', 'Author'], ['Author', 'Paper', 'Venue', 'Paper', 'Author']] specifies two metapath schemas of length 3 and 5, respectively
node_type: <str> The node attribute name that stores the node's type
seed: <int> Random number generator seed; default is None
Returns:
<list> List of lists of node names (the graph's `type_name` vertex property) for each of the random walks generated
"""
################################################################################
################################################################################
## IMPORT DEPENDENCIES
import os
from multiprocessing import Pool, Value, Lock, cpu_count
from datetime import datetime
from time import time
import numpy as np
import random
from graph_tool.all import *
import graph_tool as gt
## CREATE COUNTER CLASS FOR MULTIPROCESSING
class Counter(object):
    """Integer counter kept in a ``multiprocessing.Value`` so worker processes can share it.

    Every read and update happens while holding the Value's own lock.
    """

    def __init__(self, initval=0):
        # 'q' (signed 64-bit long long) instead of 'i' (32-bit C int): the walk
        # counter accumulates up to n walks per root node and may exceed
        # 2**31 - 1 on large graphs; ctypes integers wrap silently on overflow.
        self.val = Value('q', initval)
        # Reuse the lock that Value already carries rather than a second Lock.
        self.lock = self.val.get_lock()

    def increment(self, incval=1):
        """Add *incval* (default 1) to the counter."""
        with self.lock:
            self.val.value += incval

    def decrement(self, decval=1):
        """Subtract *decval* (default 1) from the counter."""
        with self.lock:
            self.val.value -= decval

    # Backward-compatible alias for the original (misspelled) method name.
    decriment = decrement

    def value(self):
        """Return the current counter value."""
        with self.lock:
            return self.val.value
## IMPORT DATA
# Data root; both candidate graph files live under <datasource>/graphs/.
datasource = '/media/seagate0/reddit/'  # edit if necessary
graph_sm = datasource + "graphs/g_sampled.gt"  # edit if necessary
print("Small graph source = ", graph_sm)
graph_med = datasource + "graphs/gs_SW_MultiRel_gt_med.gt"  # edit if necessary
print("Medium graph source =", graph_med)

# Ask which graph to load. Each accepted answer maps to its two progress
# messages and its file path; any other answer aborts the script.
import_graph = input("Import small or medium graph? [s/m] ").lower()
_graph_options = {
    "s": ("Importing small graph...", "Loaded small graph", graph_sm),
    "m": ("Importing medium graph...", "Loaded medium graph", graph_med),
}
if import_graph not in _graph_options:
    raise ValueError("BREAKING: please edit script to set correct graph source")
_msg_start, _msg_done, _graph_path = _graph_options[import_graph]

print(_msg_start)
graph = gt.load_graph(_graph_path)
print(_msg_done)
print(" Nodes graph:", graph.num_vertices())
print(" Edges graph:", graph.num_edges())
# The sampler expects an undirected graph, so flip the flag when needed.
if graph.is_directed():
    print("WARNING: graph is directed but should be undirected")
    print(" Transforming directed graph to undirected graph...")
    graph.set_directed(False)
## DEFINE METAPATHS
# Metapaths used by both the "min" and the "full" set.
_core_metapaths = [
    # pull together subreddits with same submission/comment author
    ["subreddit", "submission", "author", "submission", "subreddit"],
    ["subreddit", "submission", "comment", "author", "comment", "submission", "subreddit"],
    # pull together authors who submit/comment on same subreddit
    ["author", "submission", "subreddit", "submission", "author"],
    ["author", "comment", "submission", "subreddit", "submission", "comment", "author"],
    # pull together authors who interact in same submission
    ["author", "submission", "comment", "author"],
    ["author", "comment", "submission", "author"],
    ["author", "comment", "submission", "comment", "author"],
]
# Additional metapaths appended only in the "full" set.
_extra_metapaths = [
    # pull together submissions in the same subreddit
    ["submission", "subreddit", "submission"],
    # pull together submissions by the same author
    ["submission", "author", "submission"],
    # pull together comments by the same author
    ["comment", "author", "comment"],
    # pull together comments on the same submission
    ["comment", "submission", "comment"],
]

metapath_set = input("Use full or min metapath set? [full/min] ").lower()
if metapath_set == "full":
    print(" Loading full metapath set...")
    metapaths = _core_metapaths + _extra_metapaths
elif metapath_set == "min":
    print(" Loading minimum metapath set...")
    metapaths = list(_core_metapaths)
else:
    raise ValueError("ERROR: metapath not set as full or min:", metapath_set)

# Number of node types in the longest schema. Note: max(metapaths) would
# compare the lists lexicographically, not by length.
window = max(len(metapath) for metapath in metapaths)
print(" Longest metapath length =", window)

if metapath_set == "full":
    print(" Finding root nodes in full set...")
    nodes = graph.get_vertices()
    nodes_total = graph.num_vertices()
    print(" Root nodes in full set:", nodes_total)
else:
    print(" Finding root nodes in minimum set...")
    # Only subreddit and author vertices start a walk in the minimum set.
    graph_min = GraphView(
        graph,
        vfilt=lambda v: graph.vp.type[v] in ("subreddit", "author")
    )
    nodes = graph_min.get_vertices()
    nodes_total = graph_min.num_vertices()
    print(" Root nodes in minimum set:", nodes_total)
# DEFINE SAMPLING PARAMETERS
def _ask_int(prompt, default):
    """Prompt for an integer; an empty answer selects *default* (the value shown in the prompt)."""
    answer = input(prompt).strip()
    return default if answer == "" else int(answer)


n = _ask_int("How many walks per node do you want to run? [n=100] ", 100)
print(" Running {} walks per node".format(n))
length = _ask_int("What should the max length of each walk be? [length=10] ", 10)
print(" Walking a length of {} per walk".format(length))
seed_set = input("Set random seed to 407? [y/n] ").lower()
seed = 407 if seed_set == "y" else random.randint(0, 99**99)
# Seed the module-level generator as well (random.seed returns None, so it is
# called for its side effect only) and keep a dedicated generator in `rs`.
random.seed(seed)
rs = random.Random(seed)
print("Set random seed =", seed)
################################################################################
################################################################################
# Check validity of data and sampling parameters
def check_parameter_values(graph, nodes, n, length, metapaths, node_type, seed):
    """
    Checks that the parameter values are valid or raises ValueError exceptions with a message indicating the
    parameter (the first one encountered in the checks) with invalid value.
    Types are checked before numeric ranges, so a non-integer such as "5" is reported as a ValueError rather than
    failing with a TypeError from the comparison.
    Args:
        graph: <graph> The graph to walk; must not be None.
        nodes: <list> The starting nodes as a list (or array) of node IDs.
        n: <int> Number of walks per node id.
        length: <int> Maximum length of each random walk
        metapaths: <list> List of lists of node labels that specify a metapath schema, e.g.,
            [['Author', 'Paper', 'Author'], ['Author', 'Paper', 'Venue', 'Paper', 'Author']] specifies two metapath
            schemas of length 3 and 5 respectively.
        node_type: <str> The node attribute name that stores the node's type (only its type is checked here)
        seed: <int> Random number generator seed, or None
    Raises:
        ValueError: for the first parameter found to be invalid.
    """
    def is_int(value):
        # bool subclasses int but is never a meaningful count or seed.
        return isinstance(value, int) and not isinstance(value, bool)

    if graph is None:
        raise ValueError("A graph was not provided (parameter graph is None)")
    if nodes is None:
        raise ValueError("A list of starting node IDs was not provided (parameter nodes is None)")
    if len(nodes) == 0:
        raise ValueError("No starting node IDs given. An empty list will be returned as a result")
    if not is_int(n):
        raise ValueError("The number of walks per starting node, n, should be integer type")
    if n <= 0:
        raise ValueError("The number of walks per starting node, n, should be a positive integer")
    if not is_int(length):
        raise ValueError("The walk length parameter, length, should be integer type")
    if length <= 0:
        raise ValueError("The walk length parameter, length, should be positive integer")
    if not isinstance(metapaths, list):
        raise ValueError("The metapaths parameter must be a list of lists")
    for metapath in metapaths:
        if not isinstance(metapath, list):
            raise ValueError("Each metapath must be list type of node labels")
        if len(metapath) < 2:
            raise ValueError("Each metapath must specify at least two node types")
        if not all(isinstance(label, str) for label in metapath):
            raise ValueError("Node labels in metapaths must be string type")
        if metapath[0] != metapath[-1]:
            raise ValueError("The first and last node type in a metapath should be the same")
    if not isinstance(node_type, str):
        raise ValueError("The parameter node_type should be string type, not {}".format(type(node_type)))
    if seed is not None:
        if not is_int(seed):
            raise ValueError("The random number generator seed value, seed, should be integer type or None")
        if seed < 0:
            raise ValueError("The random number generator seed value, seed, should be positive integer or None")
# Walk from nodes for a given length
def _pick_next_node(neighbors, vtype, wanted_type, rng, max_attempts=100):
    """Return a neighbor whose type equals *wanted_type*, chosen uniformly at random, or None.

    Args:
        neighbors: indexable sequence of candidate vertex ids (e.g. the array
            returned by ``graph.get_out_neighbors``).
        vtype: mapping from vertex id to node type (e.g. ``graph.vp.type``).
        wanted_type: node type required at this step of the metapath.
        rng: ``random.Random`` instance used for every draw.
        max_attempts: cap on rejection-sampling draws before falling back to
            an exact scan of all neighbors.
    Returns:
        A matching neighbor, or None when no neighbor has *wanted_type*.
    """
    degree = len(neighbors)
    if degree == 0:
        return None
    # Rejection sampling first: cheap for high-degree nodes where the wanted
    # type is common. An accepted draw is uniform over the matching neighbors.
    for _ in range(min(max_attempts, degree)):
        candidate = neighbors[rng.randrange(degree)]
        if vtype[candidate] == wanted_type:
            return candidate
    # All draws missed, which does not prove there is no match: scan exactly.
    matches = [v for v in neighbors if vtype[v] == wanted_type]
    return rng.choice(matches) if matches else None


def walk_nodes(node):
    """Run ``n`` metapath-guided uniform random walks rooted at *node*.

    Reads module-level settings (graph, n, length, metapaths, seed,
    nodes_total) plus the progress counters and start time that the
    ``__main__`` block creates before starting the worker pool.

    Each walk picks, uniformly at random, a metapath whose first type equals
    the root's type, repeats the rest of that schema until it covers
    ``length`` steps, and at each step moves to a random neighbor of the next
    required type. A walk stops early when no such neighbor exists.

    Args:
        node: root vertex id.
    Returns:
        <list> of walks, each a list of at most ``length`` ``type_name`` values;
        empty when no metapath starts with the root's type.
    """
    vtype = graph.vp.type
    vname = graph.vp.type_name
    # Per-root generator derived from the global seed. Under the fork start
    # method every worker inherits `rs` in the same state and would draw the
    # same random stream; this also makes output independent of scheduling.
    rng = random.Random("{}:{}".format(seed, int(node)))
    # All n walks share the same root, so filter the metapaths once.
    root_type = vtype[node]
    candidate_metapaths = [m for m in metapaths if m[0] == root_type]
    walks = []
    if candidate_metapaths:
        for _ in range(n):
            metapath = rng.choice(candidate_metapaths)
            # Drop the root type and repeat the schema so schema[d] is the
            # type required for the node added at step d + 1.
            schema = metapath[1:] * ((length // (len(metapath) - 1)) + 1)
            current_node = node
            walk = [vname[current_node]]
            for d in range(length - 1):
                next_node = _pick_next_node(
                    graph.get_out_neighbors(current_node), vtype, schema[d], rng
                )
                if next_node is None:
                    # No neighbor of the required type: end this walk.
                    break
                current_node = next_node
                walk.append(vname[current_node])
            walks.append(walk)
    # Report progress; counters live in shared memory across workers.
    nodes_counter.increment()
    nodes_counter_n = nodes_counter.value()
    walks_counter.increment(len(walks))
    mean_time = (time() - time_start) / nodes_counter_n
    last_walk = walks[-1][:10] if walks else []
    print(" Nodes sampled: {}/{}\n\t{}".format(nodes_counter_n, nodes_total, last_walk))
    print(" Mean time:", mean_time)
    # remaining nodes * seconds per node, in hours (no division by mean_time).
    print(" Hrs. left:", (nodes_total - nodes_counter_n) * mean_time / 3600)
    return walks
################################################################################
################################################################################
## SETUP MULTIPROCESSING AND CHECK PARAMETERS
if __name__ == '__main__':
    print("\n************************************************************\n")
    print("Setting multiprocessing parameters...")
    workers = cpu_count()
    # NOTE(review): OpenBLAS presumably reads this when the library loads, and
    # numpy is imported at the top of the file, so setting it here likely has
    # no effect; the run command in the header sets it in the environment.
    os.environ["OPENBLAS_MAIN_FREE"] = "1"
    # Allow this process (and the workers forked from it) to run on CPUs
    # 0..workers-1 -- CPU numbering is zero-based. Best effort, Linux only.
    if hasattr(os, "sched_setaffinity"):
        try:
            os.sched_setaffinity(0, range(workers))
        except OSError as err:
            print("WARNING: could not set CPU affinity:", err)
    print("Checking sampler parameters...")
    check_parameter_values(
        graph=graph,
        nodes=nodes,
        n=n,
        length=length,
        metapaths=metapaths,
        node_type="type",
        seed=seed,
    )
    ## RUN SAMPLER
    print("\n************************************************************\n")
    print("Sampling from {} nodes...".format(nodes_total))
    # File name carries the graph size tag and today's date as YYYY_MM_DD.
    size_tag = {"s": "sm_", "m": "med_"}.get(import_graph, "")
    savesource = datasource + "samples/metapath_sample_{}{}.txt".format(
        size_tag, datetime.now().strftime("%Y_%m_%d")
    )  # edit if necessary
    print(" Sample will be saved to", savesource, "...")
    print(" Start runtime:", datetime.now())
    print("\n************************************************************\n")
    time_start = time()
    # Created before the Pool so the workers can read them as globals in
    # walk_nodes -- assumes the default 'fork' start method (Linux).
    nodes_counter = Counter(0)
    walks_counter = Counter(0)
    walks_written = 0
    print("Saving sample to", savesource, "...")
    # Stream each root's walks to disk as soon as it finishes: every node's
    # walks end up in the file and the full sample never sits in memory.
    with open(savesource, 'w', encoding="utf-8") as f, Pool(processes=workers) as pool:
        for walks in pool.imap(walk_nodes, nodes, chunksize=1):
            for walk in walks:
                # One walk per line, node names separated by single spaces.
                f.write(" ".join(str(step) for step in walk) + "\n")
            walks_written += len(walks)
    print("\n************************************************************\n")
    print("Sample complete: {} walks sampled from {} nodes".format(walks_counter.value(), nodes_total))
    print("\n************************************************************\n")
    print(" Walks written:", walks_written)
    print("Sample saved to", savesource)
    ############################################################################
    print("\n************************************************************\n")
    print("\n\n**** Finished script ****\n\n")
    print(" Finish runtime:", datetime.now())
    print(" Total runtime: ", time() - time_start)
    print(" Please check sample saved to", savesource)
    print(" Suggested minimum metapath2vec window size =", window)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment