@rusty1s
Last active October 5, 2022 04:32
[Community Sprint] Add missing type hints and TorchScript support

We are kicking off our very first community sprint!

The community sprint revolves around adding missing type hints and TorchScript support for various functions across PyG, aiming to improve and clean up our core codebase. Each individual contribution is designed to take only around 30 minutes to two hours to complete.

The sprint begins on Wednesday, October 12th, with a kick-off meeting at 8am PST. It will last two weeks, and we will hold another live hangout once the sprint has completed. If you are interested in helping out, please also join our PyG Slack channel #community-sprint-type-hints for more information.

🚀 Add missing type hints and TorchScript support

Type hints are currently used inconsistently in the torch-geometric repository, and we would like to make them complete and consistent across all datasets, models and utilities. Adding type hint support to models also helps us improve our TorchScript coverage across the layers and models provided in nn.*.

Example

Take a look at the current implementation of contains_isolated_nodes:

def contains_isolated_nodes(edge_index, num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    edge_index, _ = remove_self_loops(edge_index)
    return torch.unique(edge_index.view(-1)).numel() < num_nodes

Adding type hint support to the function signature helps us better understand its inputs and output, improving code readability:

def contains_isolated_nodes(
   edge_index: Tensor,
   num_nodes: Optional[int] = None,
) -> bool:
   ...

Importantly, it also lets us use the function as part of a TorchScript model. TorchScript expects every argument without a type hint to be a PyTorch tensor (which is clearly not the case for the num_nodes argument), so without the type hints the torch.jit.script variant fails:

import torch

from torch_geometric.utils import contains_isolated_nodes

contains_isolated_nodes = torch.jit.script(contains_isolated_nodes)

contains_isolated_nodes(torch.tensor([[0, 1, 0], [1, 0, 0]]))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File ".../pytorch_geometric/torch_geometric/utils/isolated.py", line 29, in contains_isolated_nodes
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
                ~~~~~~~~~~~~~~~ <--- HERE
    edge_index, _ = remove_self_loops(edge_index)
    return torch.unique(edge_index.view(-1)).numel() < num_nodes
RuntimeError: Cannot input a tensor of dimension other than 0 as a scalar argument
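
For comparison, a fully annotated variant can be scripted and checked against its eager output. The following is only a minimal sketch that depends on torch alone: infer_num_nodes, drop_self_loops and has_isolated_nodes are hypothetical stand-ins written for this illustration, not the PyG implementations of maybe_num_nodes, remove_self_loops and contains_isolated_nodes.

```python
from typing import Optional

import torch
from torch import Tensor


def infer_num_nodes(edge_index: Tensor, num_nodes: Optional[int] = None) -> int:
    # Hypothetical stand-in for maybe_num_nodes: use the given value if
    # present, otherwise derive it from the largest node index.
    if num_nodes is not None:
        return num_nodes
    if edge_index.numel() == 0:
        return 0
    return int(edge_index.max()) + 1


def drop_self_loops(edge_index: Tensor) -> Tensor:
    # Hypothetical stand-in for remove_self_loops: keep edges whose
    # source and target differ.
    mask = edge_index[0] != edge_index[1]
    return edge_index[:, mask]


def has_isolated_nodes(edge_index: Tensor, num_nodes: Optional[int] = None) -> bool:
    # A node is isolated if it does not appear in any non-self-loop edge.
    n = infer_num_nodes(edge_index, num_nodes)
    edge_index = drop_self_loops(edge_index)
    return torch.unique(edge_index.reshape(-1)).numel() < n


# Thanks to the annotations, TorchScript knows that num_nodes is an
# Optional[int] rather than a Tensor.
scripted = torch.jit.script(has_isolated_nodes)

edge_index = torch.tensor([[0, 1, 2, 0], [1, 0, 2, 0]])
print(has_isolated_nodes(edge_index))  # eager result
print(scripted(edge_index))            # scripted result, should match
```

In this input, node 2 only has a self-loop, so both calls are expected to report an isolated node.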

Guide to contributing

See here for a basic example to follow.

  1. Ensure you have read our contributing guidelines.
  2. Claim the functionality/model you want to improve. More information on this will follow soon.
  3. Implement the changes as in pyg-team/pytorch_geometric#5603. Ideally, ensure in a test that the model/function is convertible to a TorchScript program and that the scripted version produces the same output. It is okay to add type hint support for multiple functions/models within a single PR as long as you have assigned yourself to each of them and the number of changed files stays reasonable to ease reviewing (e.g., not more than 10 touched files).
  4. Open a PR to the PyG repository and name it: "[Type Hints] {model_name/function_name}". In addition, ensure that documentation is rendered properly (CI will build the documentation automatically for you). Afterwards, add your PR number to the "Improved type hint support" line in CHANGELOG.md.

Tips for making your PR

  • If you are unfamiliar with how type hints work, you can read the Python library documentation on them, but it is probably even easier to just look at another PR that added them.
  • The types will usually be obvious, e.g., Tensor, bool, int, float. Wrap them within Optional[*] whenever the argument can be None.
  • Specialized PyG type hints (e.g., Adj) are defined in typing.py.
  • In some rare cases, type hints are challenging to add, e.g., whenever a model/function supports a Union of different types (which may be the case for functions/models that also support SparseTensor, e.g., edge_index: Union[Tensor, SparseTensor]). In that case, TorchScript support can be achieved via the torch.jit._overload decorator. See here for an example.
  • The corresponding tests of PyG models and functions can be found in the test/ directory. For example, tests for torch_geometric/utils/isolated.py can be found in test/utils/test_isolated.py. You can run individual test files via pytest test/utils/test_isolated.py. It is only necessary to test TorchScript support for the nn.* and utils.* packages. No changes are necessary for the datasets.* and transforms.* packages.
  • Ensure that the TorchScript variant compiles and achieves the same output:
    from torch_geometric.testing import is_full_test
    
    ...
    
    edge_index = torch.tensor([[0, 1, 2, 0], [1, 0, 2, 0]])
    assert contains_isolated_nodes(edge_index)
      
    if is_full_test():
        jit = torch.jit.script(contains_isolated_nodes)
        assert jit(edge_index)
    Note that we generally gate TorchScript tests behind is_full_test(), which guarantees that TorchScript tests are only run nightly. You can enable full tests locally via the FULL_TEST=1 environment variable, e.g., FULL_TEST=1 pytest test/utils/test_isolated.py.
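
The gating idea behind is_full_test() can be illustrated with a small standard-library sketch. This is an assumption about the general pattern only, not PyG's actual implementation: full_test_enabled and run_checks are hypothetical names, and the exact values that PyG accepts for FULL_TEST are not confirmed here.

```python
import os
from typing import Callable, List, Mapping, Optional


def full_test_enabled(env: Optional[Mapping[str, str]] = None) -> bool:
    # Hypothetical stand-in for is_full_test(): assume the flag is
    # considered set only when FULL_TEST equals "1".
    env = os.environ if env is None else env
    return env.get("FULL_TEST", "0") == "1"


def run_checks(
    cheap: Callable[[], None],
    expensive: Callable[[], None],
    env: Optional[Mapping[str, str]] = None,
) -> List[str]:
    # Always run the cheap check; run the expensive one (e.g., a
    # TorchScript comparison) only when full tests are enabled.
    ran = ["cheap"]
    cheap()
    if full_test_enabled(env):
        expensive()
        ran.append("expensive")
    return ran


print(run_checks(lambda: None, lambda: None, env={}))
print(run_checks(lambda: None, lambda: None, env={"FULL_TEST": "1"}))
```

Passing the environment as an argument keeps the sketch easy to test; a real test suite would typically read os.environ directly.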

Functions/Models to update

This list may be incomplete. If you find another function or model with missing type hints or TorchScript tests, please let us know or add them on your own.

  • nn.MetaLayer
  • nn.InstanceNorm
  • nn.GraphSizeNorm
  • nn.MessageNorm
  • nn.DiffGroupNorm
  • nn.TopKPooling
  • nn.SAGPooling
  • nn.EdgePooling
  • nn.PANPooling
  • nn.max_pool (TorchScript support not possible)
  • nn.avg_pool (TorchScript support not possible)
  • nn.max_pool_x
  • nn.max_pool_neighbor_x (TorchScript support not possible)
  • nn.avg_pool_x
  • nn.avg_pool_neighbor_x (TorchScript support not possible)
  • nn.Node2Vec
  • nn.DeepGraphInfomax
  • nn.InnerProductEncoder
  • nn.GAE
  • nn.VGAE
  • nn.ARGA
  • nn.ARGVA
  • nn.SignedGCN
  • nn.RENet
  • nn.GraphUNet
  • nn.SchNet
  • nn.DimeNet
  • nn.GNNExplainer
  • nn.DeepGCNLayer
  • nn.AttentiveFP
  • nn.DenseGCNConv
  • nn.DenseGINConv
  • nn.DenseGraphConv
  • nn.DenseSAGEConv
  • nn.dense_diff_pool
  • nn.dense_mincut_pool
  • datasets.NELL
  • datasets.PPI
  • datasets.Reddit
  • datasets.Reddit2
  • datasets.Yelp
  • datasets.QM7b
  • datasets.ZINC
  • datasets.MoleculeNet
  • datasets.MNISTSuperpixels
  • datasets.ShapeNet
  • datasets.ModelNet
  • datasets.SHREC2016
  • datasets.TOSCA
  • datasets.PCPNetDataset
  • datasets.S3DIS
  • datasets.ICEWS18
  • datasets.WILLOWObjectClass
  • datasets.PascalVOCKeypoints
  • datasets.PascalPF
  • datasets.SNAPDataset
  • datasets.SuiteSparseMatrixCollection
  • datasets.WordNet18
  • datasets.WebKB
  • datasets.JODIEDataset
  • datasets.MixHopSyntheticDataset
  • datasets.UPFD
  • transforms.Distance
  • transforms.Cartesian
  • transforms.LocalCartesian
  • transforms.Polar
  • transforms.Spherical
  • transforms.PointPairFeatures
  • transforms.OneHotDegree
  • transforms.TargetIndegree
  • transforms.RandomJitter
  • transforms.RandomFlip
  • transforms.RandomScale
  • transforms.RandomRotate
  • transforms.RandomShear
  • transforms.KNNGraph
  • transforms.FaceToEdge
  • transforms.SamplePoints
  • transforms.FixedPoints
  • transforms.ToDense
  • transforms.LaplacianLambdaMax
  • transforms.ToSLIC
  • transforms.GDC
  • transforms.SIGN
  • transforms.SVDFeatureReduction
  • utils.remove_isolated_nodes
  • utils.get_laplacian
  • utils.to_dense_adj
  • utils.dense_to_sparse
  • utils.normalized_cut
  • utils.grid
  • utils.geodesic_distance (TorchScript support not possible)
  • utils.tree_decomposition (TorchScript support not possible)
  • utils.to_scipy_sparse_matrix (TorchScript support not possible)
  • utils.from_scipy_sparse_matrix (TorchScript support not possible)
  • utils.to_networkx (TorchScript support not possible)
  • utils.from_networkx (TorchScript support not possible)
  • utils.erdos_renyi_graph
  • utils.stochastic_blockmodel_graph
  • utils.barabasi_albert_graph
  • utils.train_test_split_edges
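
To spot further candidates that are not in the list above, a small helper based on the standard inspect module can report which parameters of a function lack annotations. This is only a sketch: missing_annotations is a hypothetical helper written for this illustration and is not part of PyG, and the two sample functions are placeholders.

```python
import inspect
from typing import Callable, List


def missing_annotations(fn: Callable) -> List[str]:
    """Return the names of parameters (and 'return') without an annotation."""
    sig = inspect.signature(fn)
    missing = [
        name
        for name, param in sig.parameters.items()
        if param.annotation is inspect.Parameter.empty
        and name not in ("self", "cls")
    ]
    if sig.return_annotation is inspect.Signature.empty:
        missing.append("return")
    return missing


# Placeholder functions for demonstration purposes only.
def untyped(edge_index, num_nodes=None):
    return edge_index


def typed(edge_index: list, num_nodes: int = 0) -> list:
    return edge_index


print(missing_annotations(untyped))  # ['edge_index', 'num_nodes', 'return']
print(missing_annotations(typed))    # []
```

The same helper could be pointed at a function imported from torch_geometric.utils to decide whether it still needs work.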