@Flunzmas
Created June 25, 2024 06:59
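
This function converts a PyTorch Geometric `Data` object representing a directed graph into its line digraph. In the line digraph, every edge of the original graph becomes a node, and there is an edge from node (u, v) to node (v, w) whenever the first edge ends where the second one starts. The original edge attributes become the new node features.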
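For intuition, the definition can be sketched on a plain Python list of `(u, v)` pairs, independent of PyTorch Geometric. The helper `line_digraph` below is hypothetical and only illustrates the definition; like the loop in the function, it takes O(E²) time.

```python
def line_digraph(edges):
    """Hypothetical helper: line-digraph arcs of a directed graph.

    Node i of the result stands for edges[i]; there is an arc i -> j
    whenever edges[i] ends at the node where edges[j] starts.
    """
    return [(i, j)
            for i, (_, v) in enumerate(edges)
            for j, (u, _) in enumerate(edges)
            if v == u]


edges = [(0, 1), (1, 2), (1, 0), (2, 0)]
print(line_digraph(edges))  # [(0, 1), (0, 2), (1, 3), (2, 0), (3, 0)]
```

For example, edge 0 = (0, 1) ends at node 1, where edges 1 = (1, 2) and 2 = (1, 0) start, which yields the arcs (0, 1) and (0, 2).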
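```python
import torch
from torch_geometric.data import Data
from torch_geometric.utils import coalesce


# `self` is unused; the function is presumably a method of a transform class.
def to_line_digraph(self, data: Data) -> Data:
    """
    Convert a directed graph into its line digraph: edge i = (u, v) of the
    input becomes node i, and node i is connected to node j = (v, w).

    TODO can we make this more efficient by removing the for-loop?
    """
    assert data.edge_index is not None
    assert data.is_directed()

    # Sort the edges and merge duplicates (their attributes are summed by default).
    edge_index, edge_attr = coalesce(data.edge_index, data.edge_attr,
                                     num_nodes=data.num_nodes)
    # Count edges *after* coalescing: merged duplicates must not become nodes.
    E = edge_index.size(1)
    row, col = edge_index

    new_edge_index = []
    for i in range(E):
        # Edges j that start where edge i ends give the arcs i -> j.
        new_col_i = torch.nonzero(row == col[i])
        new_row_i = i * torch.ones_like(new_col_i)
        new_edge_index.append(torch.cat([new_row_i, new_col_i], dim=1))
    new_edge_index = torch.cat(new_edge_index, dim=0).t()

    # The old edges become the new nodes; their attributes become node features.
    data.edge_index = new_edge_index
    data.x = edge_attr
    data.num_nodes = E
    data.edge_attr = None
    return data
```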
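The TODO in the docstring asks whether the for-loop can be removed. One possible loop-free approach, sketched here with plain `torch` only, sorts the edges by source node, builds CSR-style offsets with `torch.bincount` and `torch.cumsum`, and expands each edge's range of successor edges with `torch.repeat_interleave`. The helper name `line_digraph_edge_index` is hypothetical and not part of PyTorch Geometric.

```python
import torch


def line_digraph_edge_index(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical loop-free helper: arcs i -> j with col[i] == row[j]."""
    row, col = edge_index
    device = edge_index.device
    E = row.numel()

    # Sort edges by source node (stable, so ties keep ascending edge index);
    # afterwards the out-edges of every node form one contiguous block.
    _, perm = torch.sort(row, stable=True)

    # Out-degree of every node and CSR-style start offset of its block.
    deg = torch.bincount(row, minlength=num_nodes)
    ptr = torch.cat([deg.new_zeros(1), deg.cumsum(0)])

    # Edge i = (u, v) has one successor per edge leaving v.
    count = deg[col]
    src = torch.repeat_interleave(torch.arange(E, device=device), count)

    # Exclusive prefix sum: first output slot belonging to each edge i.
    out_start = count.cumsum(0) - count
    # Position of each output arc inside its target node's out-edge block.
    local = torch.arange(src.numel(), device=device) - out_start[src]

    # Map positions in the sorted order back to original edge indices.
    dst = perm[ptr[col][src] + local]
    return torch.stack([src, dst], dim=0)


ei = torch.tensor([[0, 1, 1, 2],
                   [1, 2, 0, 0]])
print(line_digraph_edge_index(ei, num_nodes=3))
# tensor([[0, 0, 1, 2, 3],
#         [1, 2, 3, 0, 0]])
```

Inside `to_line_digraph`, the loop and the following `torch.cat` could then be replaced by `new_edge_index = line_digraph_edge_index(edge_index, data.num_nodes)`. Because the sort is stable, the arcs should come out in the same order as in the loop (by source edge, then by ascending target edge index). A shorter loop-free alternative is `(col.view(-1, 1) == row.view(1, -1)).nonzero().t()`, but it materializes an E × E boolean mask, so its memory grows quadratically; the CSR-style version needs memory proportional to E plus the number of line-digraph edges.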