Created
May 2, 2019 13:06
-
-
Save AlexMRuch/a5d4a4edfb849f6348574f792e86d7cd to your computer and use it in GitHub Desktop.
UniformRandomMetaPathWalk_def.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Alexander Ruch | |
# Modified from github.com/stellargraph/stellargraph/blob/develop/stellargraph/data/explorer.py | |
# Original stellargraph code written for networkx | |
# This script is substantially faster and parallelized | |
# Note: This script is still not as fast as original metapath2vec authors' dictionary-based approach | |
# Run with OPENBLAS_MAIN_FREE=1 python -i UniformRandomMetaPathWalk_def.py | |
""" | |
**** Run with OPENBLAS_MAIN_FREE=1 python -i UniformRandomMetaPathWalk_def.py **** | |
Performs metapath-driven uniform random walks on heterogeneous graphs. | |
Args: | |
graph: <graph> an undirected multirelational graph-tool graph object | |
n: <int> Total number of random walks per root node | |
length: <int> Maximum length of each random walk | |
metapaths: <list> List of lists of node labels that specify a metapath schema, e.g., | |
[['Author', 'Paper', 'Author'], ['Author', 'Paper', 'Venue', 'Paper', 'Author']] specifies two metapath schemas of length 3 and 5 respectively.
node_type: <str> The node attribute name that stores the node's type | |
seed: <int> Random number generator seed; default is None | |
Returns: | |
<list> List of lists of nodes ids for each of the random walks generated | |
""" | |
################################################################################ | |
################################################################################ | |
## IMPORT DEPENDENCIES | |
import os | |
from multiprocessing import Pool, Value, Lock, cpu_count | |
from datetime import datetime | |
from time import time | |
import numpy as np | |
import random | |
from graph_tool.all import * | |
import graph_tool as gt | |
## CREATE COUNTER CLASS FOR MULTIPROCESSING | |
class Counter(object):
    """Integer counter that can be shared between multiprocessing workers.

    The count is stored in a ``multiprocessing.Value`` and every read or
    update happens while holding ``self.lock``.
    """

    def __init__(self, initval=0):
        """Create the counter with the starting value ``initval``."""
        # 'q' (signed 64-bit) instead of 'i': a C int wraps silently past
        # 2**31 - 1, which a walk counter on a large graph can reach.
        self.val = Value('q', initval)
        self.lock = Lock()

    def increment(self, incval=1):
        """Add ``incval`` to the counter."""
        with self.lock:
            self.val.value += incval

    def decrement(self, decval=1):
        """Subtract ``decval`` from the counter."""
        with self.lock:
            self.val.value -= decval

    # Original (misspelled) method name, kept so existing callers still work.
    decriment = decrement

    def value(self):
        """Return the current count."""
        with self.lock:
            return self.val.value
## IMPORT DATA
# Locations of the serialized graph-tool graphs
datasource = '/media/seagate0/reddit/'  # edit if necessary
graph_sm = datasource + "graphs/g_sampled.gt"  # edit if necessary
print("Small graph source = ", graph_sm)
graph_med = datasource + "graphs/gs_SW_MultiRel_gt_med.gt"  # edit if necessary
print("Medium graph source =", graph_med)

# Ask which graph to load; both choices are handled by the same code path
import_graph = input("Import small or medium graph? [s/m] ").lower()
_graph_choices = {"s": ("small", graph_sm), "m": ("medium", graph_med)}
if import_graph not in _graph_choices:
    raise ValueError("BREAKING: please edit script to set correct graph source")
_size_label, _graph_path = _graph_choices[import_graph]

print("Importing {} graph...".format(_size_label))
graph = gt.load_graph(_graph_path)
print("Loaded {} graph".format(_size_label))
print(" Nodes graph:", graph.num_vertices())
print(" Edges graph:", graph.num_edges())

# The sampler expects an undirected graph; convert in place when necessary
if graph.is_directed():
    print("WARNING: graph is directed but should be undirected")
    print(" Transforming directed graph to undirected graph...")
    graph.set_directed(False)
## DEFINE METAPATHS
metapath_set = input("Use full or min metapath set? [full/min] ").lower()
if metapath_set not in ("full", "min"):
    raise ValueError("ERROR: metapath not set as full or min: {}".format(metapath_set))

# Metapaths rooted at subreddits and authors; on their own they form the "min" set
_root_metapaths = [
    # pull together subreddits with same submission/comment author
    ["subreddit", "submission", "author", "submission", "subreddit"],
    ["subreddit", "submission", "comment", "author", "comment", "submission", "subreddit"],
    # pull together authors who submit/comment on same subreddit
    ["author", "submission", "subreddit", "submission", "author"],
    ["author", "comment", "submission", "subreddit", "submission", "comment", "author"],
    # pull together authors who interact in same submission
    ["author", "submission", "comment", "author"],
    ["author", "comment", "submission", "author"],
    ["author", "comment", "submission", "comment", "author"],
]
# Additional metapaths rooted at submissions and comments; only in the "full" set
_content_metapaths = [
    # pull together submissions in the same subreddit
    ["submission", "subreddit", "submission"],
    # pull together submissions by the same author
    ["submission", "author", "submission"],
    # pull together comments by the same author
    ["comment", "author", "comment"],
    # pull together comments on the same submission
    ["comment", "submission", "comment"],
]

if metapath_set == "full":
    print(" Loading full metapath set...")
    metapaths = _root_metapaths + _content_metapaths
else:
    print(" Loading minimum metapath set...")
    metapaths = list(_root_metapaths)

# max() on a list of lists compares them lexicographically, so the longest
# metapath has to be found by comparing lengths explicitly.
window = max(len(metapath) for metapath in metapaths)
print(" Longest metapath length =", window)

if metapath_set == "full":
    # every vertex is a root node
    print(" Finding root nodes in full set...")
    nodes = graph.get_vertices()
    nodes_total = graph.num_vertices()
    print(" Root nodes in full set:", nodes_total)
else:
    # only subreddit and author vertices are root nodes
    print(" Finding root nodes in minimum set...")
    graph_min = GraphView(
        graph,
        vfilt=lambda v: graph.vp.type[v] in ("subreddit", "author")
    )
    nodes = graph_min.get_vertices()
    nodes_total = graph_min.num_vertices()
    print(" Root nodes in minimum set:", nodes_total)
# DEFINE SAMPLING PARAMETERS
# An empty answer selects the default advertised in the prompt instead of
# crashing in int('').
n = int(input("How many walks per node do you want to run? [n=100] ").strip() or 100)
print(" Running {} walks per node".format(n))
length = int(input("What should the max length of each walk be? [length=10] ").strip() or 10)
print(" Walking a length of {} per walk".format(length))

seed_set = input("Set random seed to 407? [y/n] ").lower()
seed = 407 if seed_set == "y" else random.randint(0, 99**99)
# Seed the module-level generator as well as the dedicated generator `rs`
# that the walker draws from (random.seed() itself returns None).
random.seed(seed)
rs = random.Random(seed)
print("Set random seed =", seed)
################################################################################ | |
################################################################################ | |
# Check validity of data and sampling parameters | |
def _is_plain_int(value):
    """Return True when ``value`` is an ``int`` but not a ``bool``."""
    return isinstance(value, int) and not isinstance(value, bool)


def check_parameter_values(graph, nodes, n, length, metapaths, node_type, seed):
    """
    Checks that the parameter values are valid or raises ValueError exceptions with a message indicating the
    parameter (the first one encountered in the checks) with invalid value.

    Type checks run before value comparisons so that a wrongly typed argument is reported as a ValueError
    rather than surfacing as a TypeError from the comparison.

    Args:
        graph: <graph> The graph to sample from; only checked for not being None.
        nodes: <list> The starting nodes as a list of node IDs.
        n: <int> Number of walks per node id.
        length: <int> Maximum length of each random walk
        metapaths: <list> List of lists of node labels that specify a metapath schema, e.g.,
            [['Author', 'Paper', 'Author'], ['Author', 'Paper', 'Venue', 'Paper', 'Author']] specifies two metapath
            schemas of length 3 and 5 respectively.
        node_type: <str> The node attribute name that stores the node's type
        seed: <int> Random number generator seed, or None

    Raises:
        ValueError: if any parameter is missing, of the wrong type, or out of range.
    """
    if graph is None:
        raise ValueError("A graph was not provided (parameter graph is None)")
    if nodes is None:
        raise ValueError("A list of starting node IDs was not provided (parameter nodes is None)")
    # len() rather than truthiness: the truth value of a multi-element numpy array is ambiguous
    if len(nodes) == 0:
        raise ValueError("No starting node IDs given. An empty list will be returned as a result")

    if not _is_plain_int(n):
        raise ValueError("The number of walks per starting node, n, should be integer type")
    if n <= 0:
        raise ValueError("The number of walks per starting node, n, should be a positive integer")

    if not _is_plain_int(length):
        raise ValueError("The walk length parameter, length, should be integer type")
    if length <= 0:
        raise ValueError("The walk length parameter, length, should be positive integer")

    if not isinstance(metapaths, list):
        raise ValueError("The metapaths parameter must be a list of lists")
    for metapath in metapaths:
        if not isinstance(metapath, list):
            raise ValueError("Each metapath must be list type of node labels")
        if len(metapath) < 2:
            raise ValueError("Each metapath must specify at least two node types")
        for label in metapath:
            if not isinstance(label, str):
                raise ValueError("Node labels in metapaths must be string type")
        # a walk repeats its metapath, so the schema has to end where it starts
        if metapath[0] != metapath[-1]:
            raise ValueError("The first and last node type in a metapath should be the same")

    if not isinstance(node_type, str):
        raise ValueError(
            "The parameter label should be string type not as given: {}".format(type(node_type))
        )

    if seed is not None:
        if not _is_plain_int(seed):
            raise ValueError("The random number generator seed value, seed, should be integer type or None")
        if seed < 0:
            raise ValueError("The random number generator seed value, seed, should be positive integer or None")
# Walk from nodes for a given length | |
def walk_nodes(node):
    """
    Runs n metapath-driven uniform random walks rooted at a single node.

    Reads the module-level names graph, metapaths, n, length, rs, nodes_total, time_start, nodes_counter and
    walks_counter, so those must exist in the worker process before this is called.

    Args:
        node: Vertex index of the root node.
    Returns:
        <list> List of walks; each walk is a list of the visited vertices' ``type_name`` property values,
        beginning with the root. A walk is cut short when no neighbor of the type required by the metapath is
        found. The list is empty when no metapath starts with the root's type.
    """
    node_types = graph.vp.type
    node_names = graph.vp.type_name

    # Every walk starts at the same root, so the candidate metapaths are computed once
    root_type = node_types[node]
    candidate_metapaths = [metapath for metapath in metapaths if metapath[0] == root_type]

    walks = []
    # With no matching metapath there is nothing to walk (rs.choice would raise on an empty list)
    n_walks = n if candidate_metapaths else 0
    for _ in range(n_walks):
        chosen = rs.choice(candidate_metapaths)
        # Drop the root type and repeat the rest so there is a required type for each of the `length` steps
        step_types = chosen[1:] * ((length // (len(chosen) - 1)) + 1)

        current_node = node
        walk = []
        for d in range(length):
            walk.append(node_names[current_node])
            neighbors = graph.get_out_neighbors(current_node)

            # Bounded rejection sampling: draw random neighbors until one has the required type.
            # This avoids scanning all neighbors of high-degree vertices, at the price of
            # occasionally giving up even though a matching neighbor exists.
            next_node = None
            for _attempt in range(min(100, len(neighbors))):
                # randrange excludes the upper bound, and the draw is an index INTO neighbors
                candidate = int(neighbors[rs.randrange(len(neighbors))])
                if node_types[candidate] == step_types[d]:
                    next_node = candidate
                    break
            if next_node is None:
                # No usable neighbor: the walk ends here and is stored once, below
                break
            current_node = next_node
        walks.append(walk)

    # Progress report shared across workers
    nodes_counter.increment()
    nodes_counter_n = nodes_counter.value()
    walks_counter.increment(len(walks))
    mean_time = (time() - time_start) / nodes_counter_n
    last_walk_preview = walks[-1][:10] if walks else []
    print(" Nodes sampled: {}/{}\n\t{}".format(nodes_counter_n, nodes_total, last_walk_preview))
    print(" Mean time:", mean_time)
    # remaining nodes * seconds per node, converted to hours
    print(" Hrs. left:", (nodes_total - nodes_counter_n) * mean_time / 3600)
    return walks
################################################################################ | |
################################################################################ | |
## SET UP MULTIPROCESSING AND CHECK PARAMETERS
if __name__ == '__main__':
    print("\n************************************************************\n")
    print("Setting multiprocessing parameters...")
    workers = cpu_count()
    os.environ["OPENBLAS_MAIN_FREE"] = "1"
    # CPU ids are zero-based, so the last id is workers - 1
    os.system('taskset -cp 0-%d %s' % (workers - 1, os.getpid()))

    print("Checking sampler parameters...")
    check_parameter_values(
        graph=graph,
        nodes=nodes,
        n=n,
        length=length,
        metapaths=metapaths,
        node_type="type",
        seed=seed
    )

    ## RUN SAMPLER
    print("\n************************************************************\n")
    print("Sampling from {} nodes...".format(nodes_total))
    size_tag = {"s": "sm_", "m": "med_"}.get(import_graph, "")
    date_tag = datetime.now().strftime("%Y_%m_%d")
    savesource = "{}samples/metapath_sample_{}{}.txt".format(datasource, size_tag, date_tag)  # edit if necessary
    print(" Sample will be saved to", savesource, "...")
    print(" Start runtime:", datetime.now())
    print("\n************************************************************\n")

    # time_start and the counters are module-level names read by walk_nodes.
    # NOTE(review): workers only see them when processes are forked; confirm
    # the start method before running on a platform that spawns.
    time_start = time()
    nodes_counter = Counter(0)
    walks_counter = Counter(0)

    def _reseed_worker():
        """Give each worker process its own random stream."""
        # A forked worker starts with a copy of the parent's `rs` state, so
        # without this every worker would draw the same number sequence.
        rs.seed(seed + os.getpid())

    with Pool(processes=workers, initializer=_reseed_worker) as pool:
        # one list of walks per root node
        walks_by_node = pool.map(walk_nodes, nodes, chunksize=1)

    print("\n************************************************************\n")
    print("Sample complete: {} walks sampled from {} nodes".format(walks_counter.value(), nodes_total))
    print("\n************************************************************\n")

    # save sample
    print("Saving sample to", savesource, "...")
    with open(savesource, 'w') as f:
        for walks in walks_by_node:
            for walk in walks:
                # One walk per line, entries separated by single spaces.
                # NOTE(review): assumed to be the format the downstream
                # metapath2vec step reads -- confirm.
                f.write(" ".join(map(str, walk)) + "\n")
                print(" Walk written -->", walk)
                walks_counter.decriment(1)
                print(" Walks remaining:", walks_counter.value())
    print("\n************************************************************\n")
    print("Sample saved to", savesource)

    ################################################################################
    ################################################################################
    print("\n************************************************************\n")
    print("\n\n**** Finished script ****\n\n")
    print(" Finish runtime:", datetime.now())
    print(" Total runtime: ", time() - time_start)
    print(" Please check sample saved to", savesource)
    print(" Suggested minimum metapath2vec window size =", window)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment