Skip to content

Instantly share code, notes, and snippets.

@codeKgu
Created July 19, 2020 15:28
Show Gist options
Save codeKgu/54f7e779a76d45a27d8f902100dfac7a to your computer and use it in GitHub Desktop.
Snippets for the TextGCN blog post
import re

import numpy as np
import scipy.sparse as sp
import torch
def init_node_feats(self, type, device):
    """Initialise ``self.node_feats`` as a sparse one-hot (identity) matrix.

    With ``N = self.graph.shape[0]``, every node gets a one-hot feature
    vector, so the stored value is the ``N x N`` identity matrix as a
    sparse COO tensor of dtype ``torch.float``.

    Args:
        type: Feature-initialisation scheme. Only ``'one_hot_init'`` is
            supported. (The name shadows the builtin ``type`` but is kept so
            keyword callers keep working.)
        device: Device on which the feature tensor is created.

    Raises:
        ValueError: If ``type`` is not a supported scheme, instead of
            silently leaving ``self.node_feats`` unset.
    """
    if type != 'one_hot_init':
        raise ValueError(f"unsupported node feature init type: {type!r}")
    num_nodes = self.graph.shape[0]
    # The non-zeros of the identity are exactly the diagonal (i, i).
    diag = torch.arange(num_nodes, device=device)
    inds = torch.stack((diag, diag), dim=0)
    values = torch.ones(num_nodes, device=device, dtype=torch.float)
    # Give the size explicitly so the shape never depends on torch inferring
    # it from the largest index (e.g. for an empty graph).
    self.node_feats = torch.sparse_coo_tensor(
        inds, values, size=(num_nodes, num_nodes), device=device, dtype=torch.float
    )
def forward(self, x, edge_index, edge_weight=None):
    """Project node features, then run GCN-normalised message passing.

    Variant of the PyTorch Geometric GCN forward pass that also accepts a
    sparse feature matrix ``x``. Reference implementation:
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gcn_conv.html#GCNConv

    Args:
        x: Node feature matrix, dense or sparse.
        edge_index: Graph connectivity, forwarded to ``GCNConv.norm``.
        edge_weight: Optional edge weights, forwarded to ``GCNConv.norm``.

    Returns:
        The result of ``self.propagate`` on the projected features.
    """
    # Sparse inputs go through torch.sparse.mm; dense inputs use torch.matmul.
    project = torch.sparse.mm if x.is_sparse else torch.matmul
    x = project(x, self.weight)

    if self.cached and self.cached_result is not None:
        # Cache hit: reuse the (edge_index, norm) pair from an earlier call.
        edge_index, norm = self.cached_result
    else:
        # The row count of the projected features is passed as the node
        # count. The pair is stored on self.cached_result even when
        # caching is disabled; it is only reused when self.cached is set.
        edge_index, norm = GCNConv.norm(
            edge_index, x.size(0), edge_weight, self.improved, x.dtype
        )
        self.cached_result = (edge_index, norm)

    return self.propagate(edge_index, x=x, norm=norm)
# NOTE(review): excerpt from a larger text-cleaning routine; `string` is not
# defined in this snippet and is presumably the text being cleaned - confirm.
# Replace each http:// or https:// URL (plus the run of URL-like characters
# after it) with a single space.
# NOTE(review): the unescaped `.` after `//` matches any one character other
# than a newline (including a space) - confirm this is intended.
string = re.sub(r"http[s]?\:\/\/.[a-zA-Z0-9\.\/\_?=%&#\-\+!]+", " ", string)
# Replace every character that is not an ASCII letter, a digit, or one of
# ( ) _ + , ! ? : ' ` with a space; those punctuation marks are kept.
string = re.sub(r"[^A-Za-z0-9()_+,!?:\'\`]", " ", string)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment