@ilblackdragon
Created August 8, 2017 04:43
import collections

import torch
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import utils  # assumed local helper module providing the Timer used below
class Fold(object):
    """Collects ops into a DAG of nodes, then replays them in batched steps."""

    class Node(object):
        """Placeholder for the future result of `op`, computed at `step`."""

        def __init__(self, op, step, index, *args):
            self.op = op
            self.step = step
            self.index = index       # position within this (step, op) batch
            self.args = args
            self.split_idx = -1      # which output to take if op returns a tuple

        def split(self, num):
            """Split resulting node, if function returns multiple values."""
            nodes = []
            for idx in range(num):
                nodes.append(Fold.Node(
                    self.op, self.step, self.index, *self.args))
                nodes[-1].split_idx = idx
            return nodes

        def __repr__(self):
            return "[%d:%d]%s" % (
                self.step, self.index, self.op)
    def __init__(self):
        # steps[step][op] holds the argument tuples queued for one batched call.
        self.steps = collections.defaultdict(
            lambda: collections.defaultdict(list))
        # cached_nodes[op][args] memoizes nodes so identical adds are reused.
        self.cached_nodes = collections.defaultdict(dict)
        self.total_nodes = 0
    def add(self, op, *args):
        """Add op to the fold."""
        self.total_nodes += 1
        if args not in self.cached_nodes[op]:
            # An op executes one step after the latest step among its
            # Fold.Node arguments; ops with no node arguments run at step 0.
            step = max([0] + [arg.step + 1 for arg in args
                              if isinstance(arg, Fold.Node)])
            node = Fold.Node(op, step, len(self.steps[step][op]), *args)
            self.steps[step][op].append(args)
            self.cached_nodes[op][args] = node
        return self.cached_nodes[op][args]
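
    # Illustrative step assignment (a sketch, mirroring the demo at the bottom):
    # leaves added as f.add('value', i) carry no node arguments, so they sit at
    # step 0; f.add('attr', leaf, leaf) lands at step 1; an 'attr' fed by a
    # step-1 node lands at step 2, and so on. Depth in the DAG, not insertion
    # order, decides when an op runs.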
    def _batch_args(self, arg_lists, values):
        res = []
        for arg in arg_lists:
            r = []
            if isinstance(arg[0], Fold.Node):
                # Gather each node's already-computed chunk and stack them
                # into one batch along dimension 0.
                for x in arg:
                    if x.split_idx >= 0:
                        r.append(values[x.step][x.op][x.split_idx][x.index])
                    else:
                        r.append(values[x.step][x.op][x.index])
                res.append(torch.cat(r, 0))
            else:
                # Plain Python values (e.g. embedding indices) become a LongTensor.
                res.append(Variable(torch.LongTensor(arg)))
        return res
    def apply(self, net, nodes):
        """Apply current fold to given neural module."""
        values = {}
        for step in sorted(self.steps.keys()):
            values[step] = {}
            for op in self.steps[step]:
                func = getattr(net, op)
                # Transpose the queued argument tuples into per-argument lists,
                # batch them, and run the op once for the whole step.
                batched_args = self._batch_args(
                    zip(*self.steps[step][op]), values)
                res = func(*batched_args)
                if isinstance(res, (tuple, list)):
                    # Multi-output op: chunk every output back per node.
                    values[step][op] = []
                    for x in res:
                        values[step][op].append(
                            torch.chunk(x, batched_args[0].size()[0]))
                else:
                    values[step][op] = torch.chunk(
                        res, batched_args[0].size()[0])
        return self._batch_args(nodes, values)
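
# Rough execution picture for apply() (a sketch, matching the demo below): nodes
# are grouped by (step, op) and each group becomes one batched call. All step-0
# 'value' lookups run as a single te.value(LongTensor([...])), each step's
# 'attr' nodes run as a single te.attr(batched_left, batched_right), and the
# results are chunked back row by row so later steps can gather exactly the
# rows they need.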
if __name__ == "__main__":
    timer = utils.Timer()
    timer.start()
    f = Fold()
    # 'value' returns two outputs, so split its node and keep the first one.
    v1, _ = f.add('value', 1).split(2)
    v2, _ = f.add('value', 2).split(2)
    r = v1
    for i in range(1000):
        r = f.add('attr', v1, v2)
        r = f.add('attr', r, v2)
    timer.tag('fold')

    class TestEncoder(nn.Module):
        def __init__(self):
            super(TestEncoder, self).__init__()
            self.embed = nn.Embedding(10, 10)
            self.out = nn.Linear(20, 10)

        def value(self, idx):
            # Two identical outputs, matching the split(2) calls above.
            return self.embed(idx), self.embed(idx)

        def attr(self, left, right):
            return self.out(torch.cat([left, right], 1))

    te = TestEncoder()
    timer.tag('encoder: created')
    enc = f.apply(te, [[r]])
    timer.tag('encoder: apply')
    print(enc[0].size())
    print(timer.report())
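
# Note on the loop above: because add() memoizes identical (op, args) calls,
# the 2,000 adds collapse to just two unique 'attr' nodes, so apply() should
# execute only three batched calls in total (one 'value', two 'attr'), and
# enc[0] should be a single (1, 10) tensor, printed as torch.Size([1, 10])
# (expected from reading the code, not a captured run).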