Skip to content

Instantly share code, notes, and snippets.

@justheuristic
Created March 2, 2020 14:27
Show Gist options
  • Save justheuristic/01d5ffe9c534d90e40badff35653ba7d to your computer and use it in GitHub Desktop.
Save justheuristic/01d5ffe9c534d90e40badff35653ba7d to your computer and use it in GitHub Desktop.
import sys
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
sys.path.append('../')
import lib.client
# Client-side training script: chains remote experts into one network and
# trains it on MNIST by calling forward/backward through them.

# Matches the 32 `expert{i}` backends that the server script creates.
# NOTE(review): this constant is not referenced below -- the client chains only
# N_CHAINED_EXPERTS of them; confirm that is intended.
N_EXPERTS = 32
N_CHAINED_EXPERTS = 2  # how many `expert{i}` layers sit between 'inp' and 'out'
SERVER_PORT = 8099     # must be the port the server script listens on
BATCH_SIZE = 256


def _flatten(tensor):
    """Collapse a tensor into a 1-D vector.

    A named function (rather than a lambda) keeps the dataset transform
    picklable, which matters if DataLoader worker processes are ever enabled.
    """
    return tensor.view(-1)


# Remote layers, applied in order: 'inp', expert0 .. expert{k-1}, 'out'.
experts = [
    lib.client.RemoteExpert('inp', port=SERVER_PORT),
    *[
        lib.client.RemoteExpert(f'expert{i}', port=SERVER_PORT)
        for i in range(N_CHAINED_EXPERTS)
    ],
    lib.client.RemoteExpert('out', port=SERVER_PORT),
]

# MNIST is downloaded into the current directory on first use; every image is
# converted to a tensor and then flattened.
dataset = MNIST('.', download=True, transform=Compose([ToTensor(), _flatten]))
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
network = nn.Sequential(*experts)

# Iterate over the dataset indefinitely; the script has no exit condition and
# must be stopped manually.
while True:
    for x, y in loader:
        output = network(x)
        loss = torch.nn.functional.cross_entropy(output, y)
        # Report the batch loss and the batch accuracy.
        print(loss.item(), (output.argmax(dim=-1) == y).float().mean().item())
        # No optimizer is created in this script.
        # NOTE(review): presumably the backward pass through RemoteExpert lets
        # the server-side optimizers apply the update -- confirm against lib.
        loss.backward()
import sys
import torch
import torch.nn as nn
sys.path.append('../')
import lib
# Server-side script: builds every expert backend and serves them all from a
# single TesseractServer.

SERVER_PORT = 8099
N_EXPERTS = 32         # number of intermediate `expert{i}` backends
INPUT_DIM = 784        # input width of the 'inp' expert
HIDDEN_DIM = 512       # width exchanged between consecutive experts
OUTPUT_DIM = 10        # output width of the 'out' expert
MAX_BATCH_SIZE = 8192


def _make_backend(name, module, input_dim):
    """Wrap ``module`` in a ``lib.ExpertBackend`` registered under ``name``.

    The module is compiled with ``torch.jit.script`` and paired with its own
    Adam optimizer (default hyperparameters). ``input_dim`` is handed to
    ``lib.BatchTensorProto`` to describe the backend's single input argument.
    """
    return lib.ExpertBackend(
        name=name,
        expert=torch.jit.script(module),
        opt=torch.optim.Adam(module.parameters()),
        args_schema=(lib.BatchTensorProto(input_dim),),
        max_batch_size=MAX_BATCH_SIZE,
    )


# Registry of backends keyed by expert name; modules are created in the order
# 'inp', expert0 .. expert{N-1}, 'out'.
experts = {}

experts['inp'] = _make_backend(
    'inp',
    nn.Sequential(
        nn.Linear(INPUT_DIM, HIDDEN_DIM), nn.ReLU(inplace=True),
    ),
    INPUT_DIM,
)

for i in range(N_EXPERTS):
    # Each intermediate expert maps HIDDEN_DIM -> HIDDEN_DIM through a wider
    # 2048-unit bottleneck-free MLP.
    experts[f'expert{i}'] = _make_backend(
        f'expert{i}',
        nn.Sequential(
            nn.Linear(HIDDEN_DIM, 2048), nn.ReLU(inplace=True),
            nn.Linear(2048, 2048), nn.ReLU(inplace=True),
            nn.Linear(2048, HIDDEN_DIM), nn.ReLU(inplace=True),
        ),
        HIDDEN_DIM,
    )

experts['out'] = _make_backend(
    'out',
    nn.Sequential(nn.Linear(HIDDEN_DIM, OUTPUT_DIM)),
    HIDDEN_DIM,
)

# NOTE(review): the meaning of the first positional argument (None) and of
# conn_handler_processes / sender_threads is defined by lib.TesseractServer,
# which is not visible here. The device is fixed to CUDA, so this presumably
# requires a GPU -- confirm.
lib.TesseractServer(
    None,
    experts,
    port=SERVER_PORT,
    conn_handler_processes=4,
    sender_threads=1,
    device=torch.device('cuda'),
    start=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment