Skip to content

Instantly share code, notes, and snippets.

@masaponto
Created January 4, 2021 00:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save masaponto/2d235fe63c98e659fa69c6b807ed0a9a to your computer and use it in GitHub Desktop.
Save masaponto/2d235fe63c98e659fa69c6b807ed0a9a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using backend: pytorch\n"
]
}
],
"source": [
"import dgl\n",
"import torch\n",
"import numpy as np\n",
"\n",
"n_users = 1000\n",
"n_items = 500\n",
"n_follows = 3000\n",
"n_clicks = 5000\n",
"n_dislikes = 500\n",
"n_hetero_features = 10\n",
"n_user_classes = 5\n",
"n_max_clicks = 10\n",
"\n",
"follow_src = np.random.randint(0, n_users, n_follows)\n",
"follow_dst = np.random.randint(0, n_users, n_follows)\n",
"click_src = np.random.randint(0, n_users, n_clicks)\n",
"click_dst = np.random.randint(0, n_items, n_clicks)\n",
"dislike_src = np.random.randint(0, n_users, n_dislikes)\n",
"dislike_dst = np.random.randint(0, n_items, n_dislikes)\n",
"\n",
"hetero_graph = dgl.heterograph({\n",
" ('user', 'follow', 'user'): (follow_src, follow_dst),\n",
" ('user', 'followed-by', 'user'): (follow_dst, follow_src),\n",
" ('user', 'click', 'item'): (click_src, click_dst),\n",
" ('item', 'clicked-by', 'user'): (click_dst, click_src),\n",
" ('user', 'dislike', 'item'): (dislike_src, dislike_dst),\n",
" ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})\n",
"\n",
"hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)\n",
"hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)\n",
"hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))\n",
"hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()\n",
"# randomly generate training masks on user nodes and click edges\n",
"hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)\n",
"hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{('item', 'clicked-by', 'user'): {}, ('item', 'disliked-by', 'user'): {}, ('user', 'click', 'item'): {'label': tensor([5., 4., 1., ..., 4., 2., 8.]), 'train_mask': tensor([False, True, True, ..., True, False, False])}, ('user', 'dislike', 'item'): {}, ('user', 'follow', 'user'): {}, ('user', 'followed-by', 'user'): {}}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hetero_graph.edata"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Graph(num_nodes={'item': 500, 'user': 1000},\n",
" num_edges={('item', 'clicked-by', 'user'): 5000, ('item', 'disliked-by', 'user'): 500, ('user', 'click', 'item'): 5000, ('user', 'dislike', 'item'): 500, ('user', 'follow', 'user'): 3000, ('user', 'followed-by', 'user'): 3000},\n",
" metagraph=[('item', 'user', 'clicked-by'), ('item', 'user', 'disliked-by'), ('user', 'item', 'click'), ('user', 'item', 'dislike'), ('user', 'user', 'follow'), ('user', 'user', 'followed-by')])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hetero_graph"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['clicked-by', 'disliked-by', 'click', 'dislike', 'follow', 'followed-by']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hetero_graph.etypes"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1000, 10])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hetero_graph.nodes[\"user\"].data[\"feature\"].shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import dgl.nn as dglnn\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class StochasticTwoLayerRGCN(nn.Module):\n",
" def __init__(self, in_feat, hidden_feat, out_feat, rel_names):\n",
" super().__init__()\n",
" self.conv1 = dglnn.HeteroGraphConv({\n",
" rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')\n",
" for rel in rel_names\n",
" })\n",
" self.conv2 = dglnn.HeteroGraphConv({\n",
" rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')\n",
" for rel in rel_names\n",
" })\n",
"\n",
" def forward(self, blocks, x):\n",
" x = self.conv1(blocks[0], x)\n",
" x = self.conv2(blocks[1], x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class ScorePredictor(nn.Module):\n",
" def forward(self, edge_subgraph, x):\n",
" with edge_subgraph.local_scope():\n",
" edge_subgraph.ndata['h'] = x\n",
" for etype in edge_subgraph.canonical_etypes:\n",
" edge_subgraph.apply_edges(\n",
" dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)\n",
" return edge_subgraph.edata['score']\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self, in_features, hidden_features, out_features, num_classes,\n",
" etypes):\n",
" super().__init__()\n",
" self.rgcn = StochasticTwoLayerRGCN(\n",
" in_features, hidden_features, out_features, etypes)\n",
" self.pred = ScorePredictor()\n",
"\n",
" def forward(self, positive_graph, negative_graph, blocks, x):\n",
" x = self.rgcn(blocks, x)\n",
" pos_score = self.pred(positive_graph, x)\n",
" neg_score = self.pred(negative_graph, x)\n",
" return pos_score, neg_score"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['clicked-by', 'disliked-by', 'click', 'dislike', 'follow', 'followed-by']"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hetero_graph.etypes"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5000"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hetero_graph.num_edges(\"click\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5000"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hetero_graph.num_edges(\"clicked-by\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0, 1, 2, ..., 4997, 4998, 4999])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.arange(hetero_graph.num_edges(\"click\"))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"train_eid_dict = {\"click\": torch.arange(hetero_graph.num_edges(\"click\")),\n",
" \"clicked-by\": torch.arange(hetero_graph.num_edges(\"clicked-by\")),\n",
" \"dislike\": torch.arange(hetero_graph.num_edges(\"dislike\")),\n",
" \"disliked-by\": torch.arange(hetero_graph.num_edges(\"disliked-by\")),\n",
" \"follow\": torch.arange(hetero_graph.num_edges(\"follow\")),\n",
" \"followed-by\": torch.arange(hetero_graph.num_edges(\"followed-by\"))}\n",
"\n",
"sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n",
"dataloader = dgl.dataloading.EdgeDataLoader(\n",
" hetero_graph, train_eid_dict, sampler,\n",
" exclude='reverse_types',\n",
" reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click', \n",
" \"dislike\": \"disliked-by\", \"disliked-by\":\"dislike\" ,\n",
" \"follow\": \"followed-by\", \"followed-by\": \"follow\"},\n",
" negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),\n",
" batch_size=1024,\n",
" shuffle=True,\n",
" drop_last=False,\n",
" num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def compute_loss(pos_score, neg_score, canonical_etypes):\n",
" # Margin loss\n",
" all_losses = []\n",
" for given_type in canonical_etypes:\n",
" n_edges = pos_score[given_type].shape[0]\n",
" if n_edges == 0:\n",
" continue\n",
" all_losses.append((1 - neg_score[given_type].view(n_edges, -1) + pos_score[given_type].unsqueeze(1)).clamp(min=0).mean())\n",
" return torch.stack(all_losses, dim=0).mean()\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"start\n",
"0 1.6091350317001343\n",
"1 1.3787392377853394\n",
"2 1.1582247018814087\n",
"3 1.1167367696762085\n",
"4 0.9966425895690918\n",
"5 0.9260237812995911\n",
"6 0.9521988034248352\n",
"7 0.856490433216095\n",
"8 0.9148068428039551\n",
"9 0.8671050071716309\n",
"10 0.8588862419128418\n",
"11 0.8536972999572754\n",
"12 0.8532518744468689\n",
"13 0.8321673274040222\n",
"14 0.7932301163673401\n",
"15 0.8196800351142883\n",
"16 0.8536006808280945\n",
"17 0.7718626856803894\n",
"18 0.7967429161071777\n",
"19 0.7815446257591248\n",
"20 0.7920117378234863\n",
"21 0.76273113489151\n",
"22 0.8009796738624573\n",
"23 0.7884901165962219\n",
"24 0.7748621106147766\n",
"25 0.7341139316558838\n",
"26 0.7425366044044495\n",
"27 0.7425730228424072\n",
"28 0.7284682393074036\n",
"29 0.7452893257141113\n",
"30 0.7855545878410339\n",
"31 0.7431566119194031\n",
"32 0.7610621452331543\n",
"33 0.7399067282676697\n",
"34 0.751685619354248\n",
"35 0.7471165657043457\n",
"36 0.7297484278678894\n",
"37 0.7272677421569824\n",
"38 0.7470314502716064\n",
"39 0.6765744686126709\n",
"40 0.7257323861122131\n",
"41 0.6735861301422119\n",
"42 0.7471463084220886\n",
"43 0.6931560635566711\n",
"44 0.698866605758667\n",
"45 0.7085348963737488\n",
"46 0.7206322550773621\n",
"47 0.7228112816810608\n",
"48 0.7575018405914307\n",
"49 0.7389569282531738\n"
]
}
],
"source": [
"in_features = 10\n",
"hidden_features = 50\n",
"out_features = 10\n",
"num_classes = None\n",
"etypes = hetero_graph.etypes\n",
"\n",
"model = Model(in_features, hidden_features, out_features, num_classes, etypes)\n",
"#model = model.cuda()\n",
"opt = torch.optim.Adam(model.parameters())\n",
"\n",
"print(\"start\")\n",
"i= 0\n",
"epoch = 50\n",
"for i in range(epoch):\n",
" for input_nodes, positive_graph, negative_graph, blocks in dataloader:\n",
"\n",
" #blocks = [b.to(torch.device('cuda')) for b in blocks]\n",
" # positive_graph = positive_graph.to(torch.device('cuda'))\n",
" #negative_graph = negative_graph.to(torch.device('cuda'))\n",
" input_features = blocks[0].srcdata['feature']\n",
" \n",
" pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)\n",
" loss = compute_loss(pos_score, neg_score, hetero_graph.canonical_etypes)\n",
" opt.zero_grad()\n",
" loss.backward()\n",
" opt.step()\n",
" \n",
" print(i, loss.item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"n_users = 1000\n",
"n_apps = 1000\n",
"n_shops = 1000\n",
"\n",
"n_appuses = 5000\n",
"n_pays = 3000\n",
"n_hetero_features = 10\n",
"\n",
"####\n",
"#n_users = 3\n",
"#n_apps = 10\n",
"#n_shops = 10\n",
"\n",
"#n_appuses = 6\n",
"#n_pays = 5\n",
"#n_hetero_features = 10\n",
"####\n",
"\n",
"\n",
"appuse_src = np.random.randint(0, n_users, n_appuses)\n",
"appuse_dst = np.random.randint(0, n_apps, n_appuses)\n",
"pay_src = np.random.randint(0, n_users, n_shops)\n",
"pay_dst = np.random.randint(0, n_pays, n_shops)\n",
"\n",
"\n",
"user_graph = dgl.heterograph({\n",
" ('user', 'appuse', 'app'): (appuse_src, appuse_dst),\n",
" ('app', 'usedby', 'user'): (appuse_dst, appuse_src),\n",
" \n",
" # ('user', 'pay', 'shop'): (pay_src, pay_dst),\n",
" #('shop', 'payedby', 'user'): (pay_dst, pay_src)\n",
"})\n",
"\n",
"user_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)\n",
"user_graph.nodes['app'].data['feature'] = torch.randn(n_apps, n_hetero_features)\n",
"#user_graph.nodes['shop'].data['feature'] = torch.randn(n_shops, n_hetero_features)\n",
"\n",
"#hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))\n",
"#hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()\n",
"# randomly generate training masks on user nodes and click edges\n",
"#hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)\n",
"#hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Graph(num_nodes={'app': 1000, 'user': 1000},\n",
" num_edges={('app', 'usedby', 'user'): 5000, ('user', 'appuse', 'app'): 5000},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['usedby', 'appuse']"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph.etypes"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"train_eid_dict = {\"appuse\": torch.arange(user_graph.num_edges(\"appuse\")),\n",
" \"usedby\": torch.arange(user_graph.num_edges(\"usedby\")),\n",
" #\"pay\": torch.arange(user_graph.num_edges(\"pay\")),\n",
" #\"payedby\": torch.arange(user_graph.num_edges(\"payedby\"))\n",
" }\n",
"\n",
"sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n",
"dataloader = dgl.dataloading.EdgeDataLoader(\n",
" user_graph, train_eid_dict, sampler,\n",
"\n",
" exclude='reverse_types',\n",
" reverse_etypes={'appuse': 'usedby', 'usedby' :'appuse', \n",
" #\"pay\": \"payedby\", \"payedby\": \"pay\",\n",
" #\"follow\": \"followed-by\", \"followed-by\": \"follow\"\n",
" },\n",
" negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),\n",
" batch_size=1024,\n",
" shuffle=True,\n",
" drop_last=False,\n",
" num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('app', 'usedby', 'user'), ('user', 'appuse', 'app')]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph.canonical_etypes"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"start\n",
"===blocks\n",
"Block(num_src_nodes={'app': 997, 'user': 997},\n",
" num_dst_nodes={'app': 989, 'user': 994},\n",
" num_edges={('app', 'usedby', 'user'): 4478, ('user', 'appuse', 'app'): 4479},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])\n",
"Block(num_src_nodes={'app': 989, 'user': 994},\n",
" num_dst_nodes={'app': 791, 'user': 757},\n",
" num_edges={('app', 'usedby', 'user'): 3528, ('user', 'appuse', 'app'): 3653},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])\n",
"===positive\n",
"Graph(num_nodes={'app': 791, 'user': 757},\n",
" num_edges={('app', 'usedby', 'user'): 507, ('user', 'appuse', 'app'): 517},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])\n",
"===negative\n",
"Graph(num_nodes={'app': 791, 'user': 757},\n",
" num_edges={('app', 'usedby', 'user'): 507, ('user', 'appuse', 'app'): 517},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])\n",
"0 0.9861730337142944\n",
"1 0.98542320728302\n",
"2 0.9932470321655273\n",
"3 0.9804465770721436\n",
"4 0.9751787185668945\n",
"5 0.9605979919433594\n",
"6 0.9659152030944824\n",
"7 0.9588730335235596\n",
"8 0.9457771182060242\n",
"9 0.9311240911483765\n",
"10 0.9498293399810791\n",
"11 0.9599283933639526\n",
"12 0.9256765246391296\n",
"13 0.9242886900901794\n",
"14 0.9002523422241211\n",
"15 0.9019483327865601\n",
"16 0.9321237802505493\n",
"17 0.9701608419418335\n",
"18 0.9530501961708069\n",
"19 0.9441486597061157\n",
"20 0.9451694488525391\n",
"21 0.9319849014282227\n",
"22 0.9293418526649475\n",
"23 0.9047495126724243\n",
"24 0.9070008993148804\n",
"25 0.9395350217819214\n",
"26 0.9406099319458008\n",
"27 0.9003127813339233\n",
"28 0.9299486875534058\n",
"29 0.9051764011383057\n",
"30 0.9139848947525024\n",
"31 0.8974647521972656\n",
"32 0.9230623245239258\n",
"33 0.940125584602356\n",
"34 0.9325525760650635\n",
"35 0.9404858350753784\n",
"36 0.9220913648605347\n",
"37 0.949921190738678\n",
"38 0.9244025945663452\n",
"39 0.9241498708724976\n",
"40 0.9190338253974915\n",
"41 0.9269808530807495\n",
"42 0.8995455503463745\n",
"43 0.9594205617904663\n",
"44 0.9050592184066772\n",
"45 0.9337030649185181\n",
"46 0.9489398002624512\n",
"47 0.8954428434371948\n",
"48 0.9267255663871765\n",
"49 0.9470422267913818\n",
"50 0.9353700876235962\n",
"51 0.9405981302261353\n",
"52 0.8956066370010376\n",
"53 0.932508111000061\n",
"54 0.9318647384643555\n",
"55 0.9398298859596252\n",
"56 0.9388526082038879\n",
"57 0.9342505931854248\n",
"58 0.9066287279129028\n",
"59 0.9176690578460693\n",
"60 0.922206699848175\n",
"61 0.9074845910072327\n",
"62 0.9343910217285156\n",
"63 0.9248802661895752\n",
"64 0.9088714122772217\n",
"65 0.9249729514122009\n",
"66 0.9145314693450928\n",
"67 0.9406793117523193\n",
"68 0.891103208065033\n",
"69 0.9191558361053467\n",
"70 0.8796985149383545\n",
"71 0.9460409879684448\n",
"72 0.889326810836792\n",
"73 0.9306645393371582\n",
"74 0.9206209778785706\n",
"75 0.9006335735321045\n",
"76 0.9094328880310059\n",
"77 0.9488924145698547\n",
"78 0.9614546298980713\n",
"79 0.9205442667007446\n",
"80 0.9324371814727783\n",
"81 0.9419518709182739\n",
"82 0.9413037300109863\n",
"83 0.9125614166259766\n",
"84 0.9074254631996155\n",
"85 0.9183988571166992\n",
"86 0.9366220235824585\n",
"87 0.9232483506202698\n",
"88 0.9524273872375488\n",
"89 0.9082925319671631\n",
"90 0.9233659505844116\n",
"91 0.924187421798706\n",
"92 0.9294463396072388\n",
"93 0.9238454103469849\n",
"94 0.921082615852356\n",
"95 0.9481202960014343\n",
"96 0.9158851504325867\n",
"97 0.8704276084899902\n",
"98 0.9221229553222656\n",
"99 0.8926336765289307\n"
]
}
],
"source": [
"in_features = 10\n",
"hidden_features = 50\n",
"out_features = 10\n",
"num_classes = None\n",
"etypes = user_graph.etypes\n",
"\n",
"model = Model(in_features, hidden_features, out_features, num_classes, etypes)\n",
"#model = model.cuda()\n",
"opt = torch.optim.Adam(model.parameters())\n",
"\n",
"print(\"start\")\n",
"i= 0\n",
"epoch = 100\n",
"for i in range(epoch):\n",
" for j, (input_nodes, positive_graph, negative_graph, blocks) in enumerate(dataloader):\n",
" if i == 0 and j == 0:\n",
" \n",
" print(\"===blocks\")\n",
" print(blocks[0])\n",
" print(blocks[1])\n",
" print(\"===positive\")\n",
" print(positive_graph)\n",
" print(\"===negative\")\n",
" print(negative_graph)\n",
" #blocks = [b.to(torch.device('cuda')) for b in blocks]\n",
" # positive_graph = positive_graph.to(torch.device('cuda'))\n",
" #negative_graph = negative_graph.to(torch.device('cuda'))\n",
" input_features = blocks[0].srcdata['feature']\n",
" \n",
" pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)\n",
" #loss = compute_loss(pos_score, neg_score, [('user', 'appuse', 'app')])\n",
" loss = compute_loss(pos_score, neg_score, user_graph.canonical_etypes)\n",
" opt.zero_grad()\n",
" loss.backward()\n",
" opt.step()\n",
" \n",
" print(i, loss.item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 332,
"metadata": {},
"outputs": [],
"source": [
"\n",
"n_users = 3\n",
"n_apps = 10\n",
"n_shops = 10\n",
"\n",
"n_appuses = 20\n",
"n_pays = 5\n",
"n_hetero_features = 10\n",
"\n",
"appuse_src = np.random.randint(0, n_users, n_appuses)\n",
"appuse_dst = np.random.randint(0, n_apps, n_appuses)\n",
"pay_src = np.random.randint(0, n_users, n_shops)\n",
"pay_dst = np.random.randint(0, n_pays, n_shops)\n",
"\n",
"\n",
"user_graph = dgl.heterograph({\n",
" ('user', 'appuse', 'app'): (appuse_src, appuse_dst),\n",
" ('app', 'usedby', 'user'): (appuse_dst, appuse_src),\n",
" \n",
" # ('user', 'pay', 'shop'): (pay_src, pay_dst),\n",
" #('shop', 'payedby', 'user'): (pay_dst, pay_src)\n",
"})\n",
"\n",
"user_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)\n",
"user_graph.nodes['user'].data['id'] = torch.arange(n_users)\n",
"\n",
"user_graph.nodes['app'].data['feature'] = torch.randn(n_apps, n_hetero_features)\n",
"user_graph.nodes['app'].data['id'] = torch.arange(n_apps)\n",
"\n",
"\n",
"#user_graph.nodes['shop'].data['feature'] = torch.randn(n_shops, n_hetero_features)\n",
"\n",
"#hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))\n",
"#hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()\n",
"# randomly generate training masks on user nodes and click edges\n",
"#hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)\n",
"#hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)"
]
},
{
"cell_type": "code",
"execution_count": 333,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['app', 'user']"
]
},
"execution_count": 333,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph.ntypes"
]
},
{
"cell_type": "code",
"execution_count": 334,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Graph(num_nodes={'app': 10, 'user': 3},\n",
" num_edges={('app', 'usedby', 'user'): 20, ('user', 'appuse', 'app'): 20},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])"
]
},
"execution_count": 334,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph"
]
},
{
"cell_type": "code",
"execution_count": 335,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([9]), tensor([1]))"
]
},
"execution_count": 335,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph.find_edges(1, 'usedby')"
]
},
{
"cell_type": "code",
"execution_count": 336,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([2]), tensor([1]))"
]
},
"execution_count": 336,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph.find_edges(3, 'usedby')"
]
},
{
"cell_type": "code",
"execution_count": 337,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
]
},
"execution_count": 337,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph.nodes('app')"
]
},
{
"cell_type": "code",
"execution_count": 338,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 1, 2])"
]
},
"execution_count": 338,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph.nodes('user')"
]
},
{
"cell_type": "code",
"execution_count": 339,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"NodeSpace(data={'feature': tensor([[-0.0763, 0.2470, 0.2405, -0.9275, 1.1873, 0.7811, -0.7117, -0.0335,\n",
" 0.6010, -1.5579],\n",
" [-2.9575, 0.1453, 1.3627, 1.2785, -1.4339, 0.8969, -0.2493, 0.3634,\n",
" -0.3548, 0.6159],\n",
" [ 0.2646, 1.0808, 1.6442, -0.0442, 1.2733, 0.6496, 0.4562, -0.2015,\n",
" -0.7876, 0.7678]]), 'id': tensor([0, 1, 2])})"
]
},
"execution_count": 339,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph.nodes['user']"
]
},
{
"cell_type": "code",
"execution_count": 340,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 1, 2])"
]
},
"execution_count": 340,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph.nodes('user')"
]
},
{
"cell_type": "code",
"execution_count": 345,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'_TYPE': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]), '_ID': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2])}"
]
},
"execution_count": 345,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dgl.to_homogeneous(user_graph).ndata"
]
},
{
"cell_type": "code",
"execution_count": 346,
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx\n",
"homo_G = dgl.to_homogeneous(user_graph, ndata=['id'])\n",
"nx_G = homo_G.to_networkx(node_attrs=['id']).to_undirected()\n",
"pos = nx.kamada_kawai_layout(nx_G)"
]
},
{
"cell_type": "code",
"execution_count": 369,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"NodeDataView({0: {'id': tensor(0)}, 1: {'id': tensor(1)}, 2: {'id': tensor(2)}, 3: {'id': tensor(3)}, 4: {'id': tensor(4)}, 5: {'id': tensor(5)}, 6: {'id': tensor(6)}, 7: {'id': tensor(7)}, 8: {'id': tensor(8)}, 9: {'id': tensor(9)}, 10: {'id': tensor(0)}, 11: {'id': tensor(1)}, 12: {'id': tensor(2)}})"
]
},
"execution_count": 369,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nx_G.nodes(data=True)"
]
},
{
"cell_type": "code",
"execution_count": 372,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0, {'id': tensor(0)})\n",
"(1, {'id': tensor(1)})\n",
"(2, {'id': tensor(2)})\n",
"(3, {'id': tensor(3)})\n",
"(4, {'id': tensor(4)})\n",
"(5, {'id': tensor(5)})\n",
"(6, {'id': tensor(6)})\n",
"(7, {'id': tensor(7)})\n",
"(8, {'id': tensor(8)})\n",
"(9, {'id': tensor(9)})\n",
"(10, {'id': tensor(0)})\n",
"(11, {'id': tensor(1)})\n",
"(12, {'id': tensor(2)})\n"
]
}
],
"source": [
"for x in nx_G.nodes(data=True):\n",
" print(x)"
]
},
{
"cell_type": "code",
"execution_count": 377,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"cls1color = '#00FFFF'\n",
"cls2color = '#FF00FF'\n",
"colors = [cls1color if n == 0 else cls2color for n in homo_G.ndata[\"_TYPE\"]]\n",
"\n",
"label_dic = {k: int(v[\"id\"]) for k, v in nx_G.nodes(data=True)}\n",
"\n",
"nx.draw(nx_G, pos, with_labels=True, node_color=colors, labels=label_dic)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 390,
"metadata": {},
"outputs": [],
"source": [
"def show_graph(hetero_g):\n",
" homo_G = dgl.to_homogeneous(hetero_g, ndata=['id'])\n",
" nx_G = homo_G.to_networkx(node_attrs=['id']).to_undirected()\n",
" pos = nx.kamada_kawai_layout(nx_G)\n",
" print(hetero_g.nodes(\"user\"))\n",
" print(hetero_g.nodes(\"app\"))\n",
"\n",
" cls1color = '#00FFFF'\n",
" cls2color = '#FF00FF'\n",
" colors = [cls1color if n == 0 else cls2color for n in homo_G.ndata[\"_TYPE\"]]\n",
" \n",
" label_dic = {k: int(v[\"id\"]) for k, v in nx_G.nodes(data=True)}\n",
" \n",
" nx.draw(nx_G, pos, with_labels=True, node_color=colors, \n",
" labels=label_dic\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 391,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0, 1, 2])\n",
"tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_graph(user_graph)"
]
},
{
"cell_type": "code",
"execution_count": 392,
"metadata": {},
"outputs": [],
"source": [
"train_eid_dict = {\"appuse\": torch.arange(user_graph.num_edges(\"appuse\")),\n",
" \"usedby\": torch.arange(user_graph.num_edges(\"usedby\")),\n",
" #\"pay\": torch.arange(user_graph.num_edges(\"pay\")),\n",
" #\"payedby\": torch.arange(user_graph.num_edges(\"payedby\"))\n",
" }\n",
"\n",
"sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n",
"\n",
"dataloader = dgl.dataloading.EdgeDataLoader(\n",
" user_graph, train_eid_dict, sampler,\n",
"\n",
" exclude='reverse_types',\n",
" reverse_etypes={'appuse': 'usedby', 'usedby' :'appuse', \n",
" #\"pay\": \"payedby\", \"payedby\": \"pay\",\n",
" #\"follow\": \"followed-by\", \"followed-by\": \"follow\"\n",
" },\n",
" negative_sampler=dgl.dataloading.negative_sampler.Uniform(1),\n",
" batch_size=4,\n",
" shuffle=True,\n",
" drop_last=False,\n",
" num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 393,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<dgl.dataloading.pytorch.EdgeDataLoader at 0x7f9e044ff4d0>"
]
},
"execution_count": 393,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataloader"
]
},
{
"cell_type": "code",
"execution_count": 406,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Graph(num_nodes={'app': 6, 'user': 2},\n",
" num_edges={('app', 'usedby', 'user'): 1, ('user', 'appuse', 'app'): 3},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])\n",
"tensor([0, 1])\n",
"tensor([0, 1, 2, 3, 4, 5])\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for i, (input_nodes, positive_graph, negative_graph, blocks) in enumerate(dataloader):\n",
" \n",
" if i == 1:\n",
" break\n",
" \n",
" print(positive_graph)\n",
" show_graph(positive_graph)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 407,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Graph(num_nodes={'app': 4, 'user': 2},\n",
" num_edges={('app', 'usedby', 'user'): 2, ('user', 'appuse', 'app'): 2},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])"
]
},
"execution_count": 407,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#print(negative_graph)\n",
"negative_graph\n",
"#show_graph(negative_graph)"
]
},
{
"cell_type": "code",
"execution_count": 408,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Block(num_src_nodes={'app': 9, 'user': 3},\n",
" num_dst_nodes={'app': 8, 'user': 3},\n",
" num_edges={('app', 'usedby', 'user'): 18, ('user', 'appuse', 'app'): 17},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')]), Block(num_src_nodes={'app': 8, 'user': 3},\n",
" num_dst_nodes={'app': 4, 'user': 2},\n",
" num_edges={('app', 'usedby', 'user'): 12, ('user', 'appuse', 'app'): 9},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])]\n"
]
}
],
"source": [
"print(blocks)"
]
},
{
"cell_type": "code",
"execution_count": 410,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Block(num_src_nodes={'app': 9, 'user': 3},\n",
" num_dst_nodes={'app': 8, 'user': 3},\n",
" num_edges={('app', 'usedby', 'user'): 18, ('user', 'appuse', 'app'): 17},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])\n"
]
},
{
"ename": "DGLError",
"evalue": "Expect number of features to match number of nodes (len(u)). Got 23 and 24 instead.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mDGLError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-410-e4efec1adcc6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mblocks\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mshow_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mblocks\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-390-4059b7a7946d>\u001b[0m in \u001b[0;36mshow_graph\u001b[0;34m(hetero_g)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mshow_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhetero_g\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mhomo_G\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_homogeneous\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhetero_g\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mndata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'id'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mnx_G\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhomo_G\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_networkx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode_attrs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'id'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_undirected\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mpos\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkamada_kawai_layout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnx_G\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhetero_g\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnodes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"user\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.7/site-packages/dgl/convert.py\u001b[0m in \u001b[0;36mto_homogeneous\u001b[0;34m(G, ndata, edata)\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[0mcomb_ef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcombine_frames\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_edge_frames\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0metypes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol_names\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0medata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 667\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcomb_nf\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 668\u001b[0;31m \u001b[0mretg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcomb_nf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 669\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcomb_ef\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 670\u001b[0m \u001b[0mretg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0medata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcomb_ef\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.7/_collections_abc.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 839\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 840\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 842\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"keys\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 843\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.7/site-packages/dgl/view.py\u001b[0m in \u001b[0;36m__setitem__\u001b[0;34m(self, key, val)\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;34m'The HeteroNodeDataView has only one node type. '\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0;34m'please pass a tensor directly'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 81\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_set_n_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ntid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nodes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mkey\u001b[0m \u001b[0;34m:\u001b[0m \u001b[0mval\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 82\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__delitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/lib/python3.7/site-packages/dgl/heterograph.py\u001b[0m in \u001b[0;36m_set_n_repr\u001b[0;34m(self, ntid, u, data)\u001b[0m\n\u001b[1;32m 3807\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnfeats\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mnum_nodes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3808\u001b[0m raise DGLError('Expect number of features to match number of nodes (len(u)).'\n\u001b[0;32m-> 3809\u001b[0;31m ' Got %d and %d instead.' % (nfeats, num_nodes))\n\u001b[0m\u001b[1;32m 3810\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3811\u001b[0m raise DGLError('Cannot assign node feature \"{}\" on device {} to a graph on'\n",
"\u001b[0;31mDGLError\u001b[0m: Expect number of features to match number of nodes (len(u)). Got 23 and 24 instead."
]
}
],
"source": [
"print(blocks[0])\n",
"show_graph(blocks[0])"
]
},
{
"cell_type": "code",
"execution_count": 280,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Block(num_src_nodes={'app': 9, 'user': 3},\n",
" num_dst_nodes={'app': 6, 'user': 3},\n",
" num_edges={('app', 'usedby', 'user'): 17, ('user', 'appuse', 'app'): 8},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse')])\n",
"{'_TYPE': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3]), '_ID': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2])}\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print(blocks[1])\n",
"show_graph(blocks[1])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"n_users = 1000\n",
"n_apps = 1000\n",
"n_shops = 1000\n",
"\n",
"n_appuses = 5000\n",
"n_pays = 3000\n",
"\n",
"appuse_src = np.random.randint(0, n_users, n_appuses)\n",
"appuse_dst = np.random.randint(0, n_apps, n_appuses)\n",
"pay_src = np.random.randint(0, n_users, n_pays)\n",
"pay_dst = np.random.randint(0, n_shops, n_pays)\n",
"\n",
"user_graph = dgl.heterograph({\n",
" ('user', 'appuse', 'app'): (appuse_src, appuse_dst),\n",
" ('app', 'usedby', 'user'): (appuse_dst, appuse_src),\n",
" ('user', 'pay', 'shop'): (pay_src, pay_dst),\n",
" ('shop', 'payedby', 'user'): (pay_dst, pay_src)\n",
"})\n",
"\n",
"user_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)\n",
"user_graph.nodes['app'].data['feature'] = torch.randn(n_apps, n_hetero_features)\n",
"user_graph.nodes['shop'].data['feature'] = torch.randn(n_shops, n_hetero_features)\n",
"\n",
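"# NOTE: the commented-out lines below are leftovers from the earlier hetero_graph example;\n",
"# user_graph has no 'click' edges and no labels/train masks are assigned to it.\n",
"#hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))\n",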
"#hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()\n",
"# randomly generate training masks on user nodes and click edges\n",
"#hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)\n",
"#hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Graph(num_nodes={'app': 1000, 'shop': 1000, 'user': 1000},\n",
" num_edges={('app', 'usedby', 'user'): 5000, ('shop', 'payedby', 'user'): 1000, ('user', 'appuse', 'app'): 5000, ('user', 'pay', 'shop'): 1000},\n",
" metagraph=[('app', 'user', 'usedby'), ('user', 'app', 'appuse'), ('user', 'shop', 'pay'), ('shop', 'user', 'payedby')])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_graph"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"train_eid_dict = {\"appuse\": torch.arange(user_graph.num_edges(\"appuse\")),\n",
" \"usedby\": torch.arange(user_graph.num_edges(\"usedby\")),\n",
" \"pay\": torch.arange(user_graph.num_edges(\"pay\")),\n",
" \"payedby\": torch.arange(user_graph.num_edges(\"payedby\"))\n",
" }\n",
"\n",
"sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)\n",
"dataloader = dgl.dataloading.EdgeDataLoader(\n",
" user_graph, train_eid_dict, sampler,\n",
" exclude='reverse_types',\n",
" reverse_etypes={'appuse': 'usedby', 'usedby' :'appuse', \n",
" \"pay\": \"payedby\", \"payedby\": \"pay\",\n",
" #\"follow\": \"followed-by\", \"followed-by\": \"follow\"\n",
" },\n",
" negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),\n",
" batch_size=1024,\n",
" shuffle=True,\n",
" drop_last=False,\n",
" num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"start\n",
"0 1.0042177438735962\n",
"1 0.9913958311080933\n",
"2 0.9594790935516357\n",
"3 0.9345883727073669\n",
"4 0.9463564157485962\n",
"5 0.8870338201522827\n",
"6 0.8888425230979919\n",
"7 0.8598660826683044\n",
"8 0.8556907176971436\n",
"9 0.8531943559646606\n",
"10 0.8615914583206177\n",
"11 0.8251054286956787\n",
"12 0.8400945663452148\n",
"13 0.8354456424713135\n",
"14 0.8102348446846008\n",
"15 0.814399778842926\n",
"16 0.802259624004364\n",
"17 0.8318575620651245\n",
"18 0.7923575043678284\n",
"19 0.8325405716896057\n",
"20 0.8271968960762024\n",
"21 0.8054969906806946\n",
"22 0.8147201538085938\n",
"23 0.7913467288017273\n",
"24 0.7965254783630371\n",
"25 0.7753124237060547\n",
"26 0.7986363768577576\n",
"27 0.8199523687362671\n",
"28 0.7809064984321594\n",
"29 0.7893482446670532\n",
"30 0.7893662452697754\n",
"31 0.7680327296257019\n",
"32 0.8147146701812744\n",
"33 0.8083506226539612\n",
"34 0.7754764556884766\n",
"35 0.7904293537139893\n",
"36 0.7884194850921631\n",
"37 0.8081418871879578\n",
"38 0.7823575735092163\n",
"39 0.8116244673728943\n",
"40 0.8163399696350098\n",
"41 0.7872242331504822\n",
"42 0.7838910222053528\n",
"43 0.7678976655006409\n",
"44 0.7785499095916748\n",
"45 0.7835447788238525\n",
"46 0.7548676133155823\n",
"47 0.772226095199585\n",
"48 0.7840331792831421\n",
"49 0.7743468284606934\n",
"50 0.7767171263694763\n",
"51 0.7715997099876404\n",
"52 0.7703453302383423\n",
"53 0.7826955318450928\n",
"54 0.7683325409889221\n",
"55 0.7519842386245728\n",
"56 0.7937618494033813\n",
"57 0.7870801687240601\n",
"58 0.7822519540786743\n",
"59 0.7406740784645081\n",
"60 0.7858375310897827\n",
"61 0.7813568115234375\n",
"62 0.7796692252159119\n",
"63 0.737754762172699\n",
"64 0.7549033761024475\n",
"65 0.7692295908927917\n",
"66 0.7557080388069153\n",
"67 0.7579880952835083\n",
"68 0.7367973327636719\n",
"69 0.770021915435791\n",
"70 0.7784019708633423\n",
"71 0.8126579523086548\n",
"72 0.7770805358886719\n",
"73 0.7531522512435913\n",
"74 0.7973572015762329\n",
"75 0.7904865145683289\n",
"76 0.769869327545166\n",
"77 0.7686647772789001\n",
"78 0.7522403001785278\n",
"79 0.7502648830413818\n",
"80 0.7981303334236145\n",
"81 0.7639758586883545\n",
"82 0.7369908690452576\n",
"83 0.760507345199585\n",
"84 0.7594852447509766\n",
"85 0.7753238081932068\n",
"86 0.8034272789955139\n",
"87 0.7260960340499878\n",
"88 0.783507227897644\n",
"89 0.7923946976661682\n",
"90 0.7648294568061829\n",
"91 0.8086757063865662\n",
"92 0.7398483157157898\n",
"93 0.7571491599082947\n",
"94 0.7814540863037109\n",
"95 0.7343556880950928\n",
"96 0.7834877967834473\n",
"97 0.7532148361206055\n",
"98 0.7671966552734375\n",
"99 0.7945557832717896\n"
]
}
],
"source": [
"in_features = 10\n",
"hidden_features = 50\n",
"out_features = 10\n",
"num_classes = None\n",
"etypes = user_graph.etypes\n",
"\n",
"model = Model(in_features, hidden_features, out_features, num_classes, etypes)\n",
"#model = model.cuda()\n",
"opt = torch.optim.Adam(model.parameters())\n",
"\n",
"print(\"start\")\n",
"i= 0\n",
"epoch = 100\n",
"for i in range(epoch):\n",
" for input_nodes, positive_graph, negative_graph, blocks in dataloader:\n",
"\n",
" #blocks = [b.to(torch.device('cuda')) for b in blocks]\n",
" # positive_graph = positive_graph.to(torch.device('cuda'))\n",
" #negative_graph = negative_graph.to(torch.device('cuda'))\n",
" input_features = blocks[0].srcdata['feature']\n",
" \n",
" pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)\n",
" loss = compute_loss(pos_score, neg_score, user_graph.canonical_etypes)\n",
" opt.zero_grad()\n",
" loss.backward()\n",
" opt.step()\n",
" \n",
" print(i, loss.item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment