Skip to content

Instantly share code, notes, and snippets.

@ricsi98
Created July 3, 2024 18:50
Show Gist options
  • Save ricsi98/852d147463f5ee62d1b10d0aae0119a8 to your computer and use it in GitHub Desktop.
Save ricsi98/852d147463f5ee62d1b10d0aae0119a8 to your computer and use it in GitHub Desktop.
from typing import Literal
import numpy as np
import scipy
from scipy.sparse import csr_array
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_scipy_sparse_matrix
# Accepted values for the ``normalization`` argument: "sym" (D^-1/2 A D^-1/2),
# "row" (D^-1 A) or None (leave the adjacency untouched).
NormType = Literal["sym", "row", None]


def _normalize_adj(adj: csr_array, normalization: NormType):
    """Return ``adj`` rescaled by inverse node degrees.

    ``D`` denotes the diagonal matrix of the row sums of ``adj``.

    Args:
        adj: Square sparse adjacency matrix of shape (N, N).
        normalization: ``"sym"`` returns ``D^-1/2 A D^-1/2``; ``"row"``
            returns ``D^-1 A``; ``None`` returns ``adj`` itself.

    Returns:
        The normalized sparse matrix (or ``adj`` unchanged for ``None``).
        Nodes whose degree is zero get a scaling factor of 0 rather than inf.

    Raises:
        ValueError: If ``normalization`` is not ``"sym"``, ``"row"`` or None.
    """
    if normalization is None:
        return adj
    if normalization == "sym":
        exponent = -0.5
    elif normalization == "row":
        exponent = -1.0
    else:
        # Reject unknown modes explicitly instead of implicitly returning None.
        raise ValueError(
            f"normalization must be 'sym', 'row' or None, got {normalization!r}"
        )

    # np.asarray(...).ravel() handles both the 1-D result of sparse arrays and
    # the (N, 1) np.matrix returned by sparse matrices; dtype is preserved.
    deg = np.asarray(adj.sum(1)).ravel()
    N = len(deg)
    # 0 ** negative -> inf for zero-degree nodes; silence the expected divide
    # warning and zero those factors so such nodes contribute nothing.
    with np.errstate(divide="ignore"):
        d_inv = np.power(deg, exponent)
    d_inv[np.isinf(d_inv)] = 0.0
    d_mat = csr_array((d_inv, (range(N), range(N))), shape=(N, N))

    if normalization == "sym":
        return d_mat.dot(adj).dot(d_mat)
    return d_mat.dot(adj)
def _adj2laplacian(adj: csr_array):
    """Return ``D - adj``, where ``D`` is the diagonal matrix of ``adj``'s row sums.

    Args:
        adj: Square sparse (possibly already normalized) adjacency matrix.

    Returns:
        The sparse Laplacian-style matrix ``D - adj``.
    """
    degrees = np.asarray(adj.sum(axis=1)).ravel()
    diag_idx = np.arange(degrees.shape[0])
    size = diag_idx.size
    degree_mat = csr_array((degrees, (diag_idx, diag_idx)), shape=(size, size))
    return degree_mat - adj
def _repeated_products(op, x, k):
    """Return ``[op @ x, op^2 @ x, ..., op^k @ x]`` (empty when ``k <= 0``)."""
    out = []
    cur = x
    for _ in range(k):
        cur = op.dot(cur)
        out.append(cur)
    return out


def poly_features(
    adj: csr_array,
    features: np.ndarray,
    identity: bool = True,
    adj_powers: int = 2,
    laplacian_powers: int = 2,
    normalization: NormType = "sym",
):
    """Concatenate graph-propagated copies of ``features`` column-wise.

    With ``A`` = ``adj`` normalized according to ``normalization`` and
    ``L = D - A`` built from that normalized ``A``, the stacked blocks are
    ``[X, A X, ..., A^k X, L X, ..., L^m X]`` (``X`` only if ``identity``).

    Args:
        adj: Square sparse adjacency matrix (N x N).
        features: Node feature matrix with N rows; dense ndarray or any
            scipy.sparse matrix/array.
        identity: Whether to include the raw ``features`` as the first block.
        adj_powers: Highest power k of ``A`` to apply (<= 0 disables).
        laplacian_powers: Highest power m of ``L`` to apply (<= 0 disables).
        normalization: Forwarded to ``_normalize_adj`` ("sym", "row" or None).

    Returns:
        ``scipy.sparse.hstack`` of the blocks when ``features`` is sparse,
        otherwise ``np.hstack`` of the blocks.

    Raises:
        ValueError: If the options select no blocks at all, or if
            ``normalization`` is rejected by ``_normalize_adj``.
    """
    adj = _normalize_adj(adj, normalization)
    acc = [features] if identity else []
    acc.extend(_repeated_products(adj, features, adj_powers))
    if laplacian_powers > 0:
        # Only build the Laplacian when it is actually needed.
        laplacian = _adj2laplacian(adj)
        acc.extend(_repeated_products(laplacian, features, laplacian_powers))
    if not acc:
        raise ValueError(
            "poly_features: no feature blocks selected "
            "(identity=False and both powers <= 0)"
        )
    # issparse covers csr_matrix/coo/etc., not only csr_array; np.hstack on
    # scipy sparse objects would silently build an object array.
    if scipy.sparse.issparse(features):
        return scipy.sparse.hstack(acc)
    return np.hstack(acc)
if __name__ == "__main__":
    # Demo on the Cora graph loaded through torch_geometric's Planetoid
    # dataset (stored under data/Planetoid). Guarded so that importing this
    # module does not trigger dataset loading.
    data = Planetoid(root="data/Planetoid", name="Cora")[0]
    # Pass num_nodes explicitly: without it the matrix size is inferred from
    # the largest index in edge_index, which is too small when the
    # highest-indexed nodes are isolated and would no longer match data.x.
    adj_matrix = to_scipy_sparse_matrix(
        data.edge_index, num_nodes=data.num_nodes
    ).tocsr()
    x = data.x.numpy()
    x_poly = poly_features(adj_matrix, x, adj_powers=2, laplacian_powers=2)
    print(x_poly.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment