@Abusagit
Last active November 25, 2023 12:28
At the time of creating this gist, any node sampling routine in DGL v1.1.2 creates message flow graphs that trigger `CUDA: an illegal memory access was encountered` during message passing. This code works around the error by constructing an identical subgraph with the same node features for every given depth of the message flow graphs. Also it add…
import dgl
import torch
from copy import deepcopy
from typing import Any


def construct_subgraph_from_blocks(
    blocks: list[Any],
    batch_size: int,
    node_attributes_to_copy: list[str],
) -> dgl.DGLGraph:
    """
    Constructs a regular graph that copies a stack of message flow graphs (MFGs),
    given as a list of MFGs.

    NOTE: this function is an example of constructing a graph for node
    classification tasks; the resulting graph only contains node features.

    params:
        `blocks`: list of consecutive message flow graphs, len(blocks) == number of layers in graph convolution
        `batch_size`: number of destination nodes
        `node_attributes_to_copy`: list of names of node attributes to copy to a new graph
    """
    # Convert every MFG to a graph and merge them into a single graph.
    merged_block = deepcopy(dgl.merge([dgl.block_to_graph(b) for b in blocks]))

    # Rebuild a plain graph from the merged edge list.
    row_coords, col_coords = merged_block.edges(form='uv', order='srcdst')
    new_graph = dgl.graph(data=(row_coords, col_coords))

    # Copy the requested node features from the source side of the merged graph.
    for node_data_name in node_attributes_to_copy:
        new_graph.ndata[node_data_name] = merged_block.srcdata[node_data_name]

    # Create a mask marking only the destination nodes (the first `batch_size` nodes).
    num_of_nodes = new_graph.num_nodes()
    output_mask = torch.zeros(num_of_nodes, dtype=torch.bool, device=new_graph.device)
    output_mask[:batch_size] = True
    new_graph.ndata["output_mask"] = output_mask

    return new_graph
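
The function above relies on one convention: the destination nodes occupy the first `batch_size` node IDs, and every block uses the same local IDs for the same nodes. The toy sketch below illustrates that idea in plain Python, without DGL. The helper `merge_blocks_sketch` and the edge lists are hypothetical, and dropping duplicate edges is a simplification of this sketch, not a claim about how `dgl.merge` behaves.

```python
def merge_blocks_sketch(blocks, batch_size):
    """Toy stand-in for the merge step (hypothetical helper, not a DGL API).

    `blocks` is a list of edge lists; each edge is a (src, dst) pair of
    local node IDs that are assumed to be shared across all blocks.
    """
    merged_edges = []
    seen = set()
    for edges in blocks:
        for edge in edges:
            if edge not in seen:  # simplification: keep each edge only once
                seen.add(edge)
                merged_edges.append(edge)

    # Node count is inferred from the largest node ID that appears in an edge.
    num_nodes = 1 + max(max(u, v) for u, v in merged_edges)

    # Destination (seed) nodes are assumed to be the first `batch_size` IDs.
    output_mask = [i < batch_size for i in range(num_nodes)]
    return merged_edges, output_mask


# Two-layer example with batch_size = 2 seed nodes (IDs 0 and 1).
outer_block = [(2, 0), (3, 1), (4, 2), (5, 3)]  # input layer: 6 src -> 4 dst nodes
inner_block = [(2, 0), (3, 1)]                  # output layer: 4 src -> 2 dst nodes
edges, mask = merge_blocks_sketch([outer_block, inner_block], batch_size=2)
```

In the real function, the same mask would typically be used to keep only the rows of the model output that belong to the seed nodes.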