Last active November 25, 2023 12:28
Gist: Abusagit/5b444130e57b8e1272a03d58837879f1
At the time of creating this gist, any node sampling routine in DGL v1.1.2 creates message flow graphs (MFGs) that trigger `CUDA: an illegal memory access was encountered` during message passing. This code works around the error by building a single equivalent ordinary subgraph, with the same node features, that covers every depth of the message flow graphs. Also it add…
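The workaround relies on how the MFGs index their nodes. The toy sketch below models each MFG as a hypothetical `(num_src, num_dst, edges)` tuple of local node IDs (not a DGL type) and assumes the layout DGL's neighbor samplers produce by default: the destination nodes are the first source nodes of each MFG, and the destination nodes of one MFG are the source nodes of the next. Under that assumption every local ID already refers to a source node of the first MFG, so concatenating the edge lists yields one ordinary graph whose first `batch_size` nodes are the seeds.

```python
from typing import List, Tuple

# Hypothetical toy stand-in for a DGL MFG (not a DGL type):
# (num_src, num_dst, edges), with edges given as (src, dst) pairs of local node IDs.
ToyBlock = Tuple[int, int, List[Tuple[int, int]]]


def merge_toy_blocks(blocks: List[ToyBlock], batch_size: int):
    """Merge consecutive toy MFGs into one homogeneous graph.

    Assumes DGL's default MFG layout: destination nodes are the first num_dst
    source nodes of each block, and block i's destination nodes are block i+1's
    source nodes. Every local ID then names a source node of blocks[0].
    """
    num_nodes = blocks[0][0]
    edges: List[Tuple[int, int]] = []
    for num_src, num_dst, block_edges in blocks:
        for u, v in block_edges:
            if not (0 <= u < num_src and 0 <= v < num_dst):
                raise ValueError(f"edge ({u}, {v}) is outside the block's local ID range")
            edges.append((u, v))
    # Concatenate without deduplication, then sort by (src, dst).
    edges.sort()
    # Only the seed nodes (the last block's destination nodes) produce outputs.
    output_mask = [node_id < batch_size for node_id in range(num_nodes)]
    return num_nodes, edges, output_mask


# Two-layer example: 5 sampled nodes in total, 2 seed nodes.
num_nodes, edges, mask = merge_toy_blocks(
    [(5, 3, [(3, 0), (4, 1), (1, 2)]),
     (3, 2, [(2, 0), (1, 0), (2, 1)])],
    batch_size=2,
)
print(num_nodes, edges, mask)
# prints: 5 [(1, 0), (1, 2), (2, 0), (2, 1), (3, 0), (4, 1)] [True, True, False, False, False]
```

The sketch sorts edges by (source, destination), which is the same ordering the gist requests with `order='srcdst'`; an edge sampled in more than one layer simply appears more than once in the toy output.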
```python
from copy import deepcopy
from typing import Any

import dgl
import torch


def construct_subgraph_from_blocks(blocks: list[Any],
                                   batch_size: int,
                                   node_attributes_to_copy: list[str],
                                   ) -> dgl.DGLGraph:
    """
    Construct a single ordinary graph equivalent to a list of message flow graphs (MFGs).
    NOTE: this function is an example for node classification tasks; the new graph only contains node features.
    params:
    `blocks`: list of consecutive message flow graphs, len(blocks) == number of layers in graph convolution
    `batch_size`: number of destination (seed) nodes of the last MFG
    `node_attributes_to_copy`: list of names of node attributes to copy to the new graph
    """
    # Convert every MFG into a regular graph and merge them. In DGL's MFG layout the
    # destination nodes come first among the source nodes, and the destination nodes of
    # one MFG are the source nodes of the next, so all local IDs index blocks[0]'s source nodes.
    merged_block = deepcopy(dgl.merge([dgl.block_to_graph(b) for b in blocks]))
    row_coords, col_coords = merged_block.edges(form='uv', order='srcdst')
    # Pass num_nodes explicitly: otherwise dgl.graph infers it from the largest node ID
    # in the edge list, and a trailing node without edges would be dropped, leaving the
    # copied feature tensors longer than the graph.
    new_graph = dgl.graph(data=(row_coords, col_coords),
                          num_nodes=blocks[0].num_src_nodes())
    for node_data_name in node_attributes_to_copy:
        new_graph.ndata[node_data_name] = merged_block.srcdata[node_data_name]
    # Mask marking only the destination (seed) nodes, whose outputs are needed downstream.
    num_of_nodes = new_graph.num_nodes()
    # Create the mask on the graph's device; DGL rejects node data on a different device.
    output_mask = torch.zeros(num_of_nodes, dtype=torch.bool, device=new_graph.device)
    output_mask[:batch_size] = True
    new_graph.ndata["output_mask"] = output_mask
    return new_graph
```
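The merged graph is only correct if the MFGs are consecutive and `batch_size` matches the number of seeds of the last MFG, so a cheap check before calling `construct_subgraph_from_blocks` may help. The helper below, `check_mfg_chain`, is hypothetical (not part of DGL). It only relies on `num_src_nodes()` and `num_dst_nodes()`, which DGL blocks provide, and it compares node counts only, so it cannot prove the ID-prefix layout itself.

```python
from typing import Sequence


def check_mfg_chain(blocks: Sequence, batch_size: int) -> None:
    """Hypothetical sanity check of the MFG layout the merge relies on.

    Works on any objects exposing num_src_nodes() and num_dst_nodes(), as DGL
    blocks do. Raises ValueError when the node counts do not line up.
    """
    if not blocks:
        raise ValueError("expected at least one MFG")
    for i, block in enumerate(blocks):
        # Destination nodes are a prefix of the source nodes, so there cannot be more of them.
        if block.num_dst_nodes() > block.num_src_nodes():
            raise ValueError(f"block {i}: more destination than source nodes")
    for i in range(len(blocks) - 1):
        # The destination nodes of layer i must be the source nodes of layer i + 1.
        if blocks[i].num_dst_nodes() != blocks[i + 1].num_src_nodes():
            raise ValueError(f"blocks {i} and {i + 1} are not consecutive MFGs")
    if blocks[-1].num_dst_nodes() != batch_size:
        raise ValueError(
            f"batch_size={batch_size}, but the last MFG has "
            f"{blocks[-1].num_dst_nodes()} destination nodes"
        )
```

In a typical DGL loop over `input_nodes, output_nodes, blocks`, one could call `check_mfg_chain(blocks, len(output_nodes))` and pass the same `len(output_nodes)` as `batch_size`. This also catches passing the configured batch size for a final batch that turns out smaller.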