Skip to content

Instantly share code, notes, and snippets.

@Flunzmas
Last active September 27, 2021 12:49
Show Gist options
Save Flunzmas/0d35f67a3f5e73bdb952e1960b4b2388 to your computer and use it in GitHub Desktop.
PyG: Access individual graphs from a Batch object not created through Batch.from_data_list()
from torch_geometric.data import Data as GraphData
def _unbatch_timestep(batch_at_timestep):
    """Split one collated PyG batch back into its individual ``GraphData`` graphs.

    The batch is read through its ``(key, value)`` iteration protocol and must
    provide the attributes ``batch``, ``x``, ``edge_index``, ``edge_attr`` and
    ``y``. Only those are sliced; any other attribute is ignored.

    Layout assumptions (not validated here; TODO confirm against the DataLoader
    that produces the batches):

    * graph ids in ``batch`` are sorted, so each graph's nodes are contiguous
      and numbered consecutively in the global ``edge_index``;
    * every graph has the same number of edges per node, so the ``E`` edges
      can be viewed as ``BV`` equal chunks, one chunk per node, in node order;
    * ``x`` and ``y`` are node-level (their first dimension is ``BV``).

    Returns:
        A list with one ``GraphData`` per distinct graph id, in ascending id
        order.
    """
    batch_data = {key: value for key, value in batch_at_timestep}
    # Graph id of every node. Named so it does not shadow the caller's batch_idx.
    node_batch = batch_data.pop("batch")
    graph_ids = node_batch.unique()
    num_nodes_total = len(node_batch)  # BV = batch_size * |V|

    # View the E edges as BV per-node chunks so the node mask can select them:
    # (2, E) -> (2, BV, E/BV) -> (BV, 2, E/BV)
    edge_index = batch_data["edge_index"].reshape(2, num_nodes_total, -1).transpose(0, 1)
    edge_attr = batch_data["edge_attr"]
    # Keep any per-edge feature dims so they can be restored after slicing.
    attr_tail_shape = edge_attr.shape[1:]
    edge_attr = edge_attr.reshape(num_nodes_total, -1)
    node_x, node_y = batch_data["x"], batch_data["y"]

    graphs, node_start_idx = [], 0
    for graph_id in graph_ids:
        mask = node_batch == graph_id
        x = node_x[mask]
        # Undo the chunked view: (V_i, 2, E/BV) -> (2, V_i * E/BV)
        graph_edge_index = edge_index[mask].transpose(0, 1).reshape(2, -1)
        graphs.append(GraphData(
            x=x,
            # Revert batch aggregation by subtracting this graph's first node index.
            edge_index=graph_edge_index - node_start_idx,
            edge_attr=edge_attr[mask].reshape(-1, *attr_tail_shape),
            y=node_y[mask],
        ))
        node_start_idx += x.shape[0]
    return graphs


# ... load training data
train_data = None
# uses the following DataLoader: https://gist.github.com/Flunzmas/5a5c8c8fd553609359704be3174db793
data_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4, drop_last=True)
for batch_idx, data in enumerate(data_loader):
    # presumably ``data`` holds one collated batch per timestep; verify against the DataLoader
    for t, batch_at_timestep in enumerate(data):
        # Construct this timestep's individual graphs from the sliced batch data.
        graphs_t = _unbatch_timestep(batch_at_timestep)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment