@bnsh · Created August 23, 2020 12:09
#! /usr/bin/env python3
# vim: expandtab shiftwidth=4 tabstop=4
"""This program uses an xor network to test MyDataParallel"""
import argparse
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.optim as optim


class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        # __getattr__ only fires when normal lookup fails.  Try the regular
        # nn.Module machinery first; delegating to self.module unconditionally
        # would recurse forever, since "module" itself lives in self._modules
        # and looking it up would re-enter this very method.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


def create_network(hiddensz):
    return nn.Sequential(OrderedDict([
        ("input", nn.Linear(2, hiddensz)),
        ("tanh", nn.Tanh()),
        ("logits", nn.Linear(hiddensz, 1)),
    ]))
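

# Illustrative sketch (never called by main) of what MyDataParallel buys:
# plain nn.DataParallel cannot see the wrapped module's attributes, while
# MyDataParallel forwards the lookup to the module it wraps.
def _demo_attribute_forwarding():
    plain = nn.DataParallel(create_network(5))
    mine = MyDataParallel(create_network(5))
    print(mine.logits)  # resolves to the inner nn.Linear(hiddensz, 1)
    try:
        plain.logits
    except AttributeError as err:
        print("nn.DataParallel:", err)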


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--parallelize", action="store_true")
    args = parser.parse_args()

    if args.parallelize:
        network = MyDataParallel(create_network(5))
    else:
        network = create_network(5)

    crit = nn.BCEWithLogitsLoss()
    opt = optim.Adamax(network.parameters())

    inputs = torch.FloatTensor([
        [0, 0],
        [0, 1],
        [1, 0],
        [1, 1],
    ])
    targets = torch.FloatTensor([
        [0],
        [1],
        [1],
        [0],
    ])
    # Duplicate the four XOR rows 1024 times, so the batch is large enough
    # for data parallelism to make any kind of sense.
    inputs = inputs.repeat([1024, 1])
    targets = targets.repeat([1024, 1])
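    # After repeating, inputs has shape (4096, 2) and targets (4096, 1):
    # 4096 rows give DataParallel something meaningful to scatter across GPUs.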
    # Train indefinitely; interrupt with Ctrl-C once the loss is low enough.
    epoch = 0
    while True:
        opt.zero_grad()
        out = network(inputs)
        loss = crit(out, targets)
        loss.backward()
        opt.step()
        epoch += 1
        print("%d: %.7f" % (epoch, loss.item()))


if __name__ == "__main__":
    main()
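

# Example invocation, assuming the script is saved as xor_dataparallel.py
# (the filename is hypothetical; pick any name):
#   python3 xor_dataparallel.py               # plain single-module training
#   python3 xor_dataparallel.py --parallelize # same run through MyDataParallel
# Both modes print "<epoch>: <loss>" once per training step.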