Skip to content

Instantly share code, notes, and snippets.

@RaulPPelaez
Created May 13, 2024 14:02
Show Gist options
  • Save RaulPPelaez/36b6a3a4bbdb0c373beaf3c1376e8f49 to your computer and use it in GitHub Desktop.
Save RaulPPelaez/36b6a3a4bbdb0c373beaf3c1376e8f49 to your computer and use it in GitHub Desktop.
from tinygrad import Tensor, nn
from tinygrad.dtype import dtypes
import numpy as np
from typing import Optional
from utils import *
import pytest
def get_einsum_subscript(extra_dims: int) -> str:
    """Build the einsum subscript contracting an ``"ij"`` mask with edge features.

    The first operand is always the 2-D mask ``"ij"``. In this module's callers,
    ``i`` indexes nodes and ``j`` indexes edges. The second operand's leading
    axis is ``j``, followed by ``extra_dims + 1`` trailing axes. The result keeps
    axis ``i`` plus the same trailing axes, so ``j`` is summed out.

    Args:
        extra_dims: ``edge_features.ndim - mask.ndim``. ``-1`` means 1-D edge
            features (no trailing axes), ``0`` means one trailing axis, etc.

    Returns:
        A subscript string, e.g. ``"ij,jk->ik"`` for ``extra_dims=0``.

    Raises:
        ValueError: if ``extra_dims < -1``, or if more trailing axes are needed
            than there are spare letters.
    """
    # The first six letters keep the original ordering so every previously
    # produced subscript is unchanged; the remaining letters extend support to
    # higher-rank features. 'i' and 'j' are reserved for the mask axes.
    letters = "klhmno" + "abcdefgpqrstuvwxyz"
    n_trailing = extra_dims + 1
    if n_trailing < 0:
        # A negative slice bound would silently return the wrong letters.
        raise ValueError(f"extra_dims must be >= -1, got {extra_dims}")
    if n_trailing > len(letters):
        # Slicing past the end would silently truncate the subscript.
        raise ValueError(
            f"extra_dims={extra_dims} needs {n_trailing} trailing axes; "
            f"at most {len(letters)} are supported"
        )
    trailing = letters[:n_trailing]
    return f"ij,j{trailing}->i{trailing}"
def aggregate_edge_features(n_nodes: int, edge_features: Tensor, edge_index: Tensor):
    """Sum each edge's features into the node that ``edge_index`` assigns it to.

    A boolean mask ``membership[i, e] = (edge_index[e] == i)`` is built for
    ``i`` in ``range(n_nodes)`` and contracted with ``edge_features`` over the
    edge axis. Edges whose index matches no node in ``[0, n_nodes)`` therefore
    contribute nothing.

    NOTE(review): assumes ``edge_index`` is 1-D with length
    ``edge_features.shape[0]``; the ``"ij,j..."`` subscript requires that.

    Returns:
        Tensor of shape ``(n_nodes, *edge_features.shape[1:])``.
    """
    node_ids = Tensor.arange(n_nodes).unsqueeze(-1)
    membership = edge_index == node_ids
    subscript = get_einsum_subscript(edge_features.ndim - membership.ndim)
    return Tensor.einsum(subscript, membership, edge_features)
def aggregate_edge_features_np(
    n_nodes: int, edge_features: np.ndarray, edge_index: np.ndarray
) -> np.ndarray:
    """NumPy reference implementation of ``aggregate_edge_features``.

    Sums each edge's features into the node given by ``edge_index``. Edges
    whose index falls outside ``[0, n_nodes)`` match no node and are dropped.

    Args:
        n_nodes: Number of output rows.
        edge_features: Array (or array-like) whose leading axis indexes edges.
        edge_index: 1-D integer array (or array-like) of node ids, one per edge.
            Its length must equal ``edge_features.shape[0]``.

    Returns:
        Array of shape ``(n_nodes, *edge_features.shape[1:])``.
    """
    # The original annotations claimed tinygrad Tensors, but np.arange and
    # np.einsum only work on NumPy data; coerce array-likes (e.g. lists) too.
    edge_features = np.asarray(edge_features)
    edge_index = np.asarray(edge_index)
    # mask[i, e] is True when edge e is assigned to node i.
    mask = edge_index == np.arange(n_nodes).reshape(-1, 1)
    subscript = get_einsum_subscript(edge_features.ndim - mask.ndim)
    return np.einsum(subscript, mask, edge_features)
@pytest.mark.parametrize("ndims", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("last_dim_size", [1, 10])
@pytest.mark.parametrize("n_nodes", [1, 10])
@pytest.mark.parametrize("n_edges_per_node", [1, 10])
def test_aggregate_edge_features(ndims, last_dim_size, n_nodes, n_edges_per_node):
    """Compare the tinygrad aggregation against the NumPy reference.

    ``edge_features`` has shape ``(n_edges, 1, ..., 1, last_dim_size)`` with
    ``ndims`` trailing axes. ``edge_index`` holds random node ids.
    """
    # Fixed seed so a failing parametrization reproduces on rerun.
    Tensor.manual_seed(0)
    n_edges = n_edges_per_node * n_nodes
    # NOTE(review): relies on Tensor.uniform(..., dtype=int) truncating samples
    # to ids in [0, n_nodes) — confirm against the tinygrad version in use.
    edge_index = Tensor.uniform(n_edges, low=0, high=n_nodes, dtype=dtypes.int)
    feature_shape = [1] * ndims
    feature_shape[-1] = last_dim_size
    edge_features = Tensor.rand(n_edges, *feature_shape)

    actual = aggregate_edge_features(n_nodes, edge_features, edge_index).numpy()
    reference = aggregate_edge_features_np(
        n_nodes, edge_features.numpy(), edge_index.numpy()
    )

    # np.allclose broadcasts, so a wrong output shape could pass unnoticed;
    # pin the shape explicitly before comparing values.
    assert actual.shape == reference.shape == (n_nodes, *feature_shape)
    # Same tolerances as np.allclose's defaults, but with a diagnostic report.
    np.testing.assert_allclose(actual, reference, rtol=1e-5, atol=1e-8)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment