Created
January 22, 2021 21:22
-
-
Save ZolotukhinM/bb7d5c5fcd483739ac706d326e171b06 to your computer and use it in GitHub Desktop.
3 Versions of DeepAndWide model: TorchScript, StaticRuntime, and NNC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import os | |
from timeit import default_timer as timer | |
class DeepAndWide(torch.nn.Module):
    """Reference (eager / TorchScript-able) version of the DeepAndWide model.

    The model normalizes and clamps the "wide" features, computes a batched
    dot product between the ad and user embeddings, concatenates both and
    feeds the result through a single affine layer followed by a sigmoid.

    All weights are random tensors stored as plain attributes.
    """

    def __init__(self, num_features=50):
        """Create random weights for ``num_features`` wide features."""
        super().__init__()
        wide_shape = (1, num_features)
        # Keep the draw order (mu, sigma, fc_w, fc_b) so that a fixed RNG seed
        # yields the same weights.
        self.mu = torch.randn(*wide_shape)
        self.sigma = torch.randn(*wide_shape)
        # One extra column for the dot-product feature prepended in forward().
        self.fc_w = torch.randn(1, num_features + 1)
        self.fc_b = torch.randn(1)

    def forward(self, ad_emb_packed, user_emb, wide):
        """Return ``sigmoid(fc(cat(dot(ad, user), clamp((wide + mu) * sigma))))``.

        ``user_emb`` is transposed on its last two dims before ``bmm`` with
        ``ad_emb_packed``; ``wide`` is broadcast against ``mu`` / ``sigma``.
        """
        # Wide branch: shift, scale, then clamp into [0, 10].
        wide_preproc = torch.clamp((wide + self.mu) * self.sigma, 0., 10.)

        # Deep branch: batched matmul of ad embeddings with transposed user
        # embeddings, flattened to 2-D so it can be concatenated.
        dp = torch.flatten(
            torch.bmm(ad_emb_packed, torch.transpose(user_emb, 1, 2)), 1, -1
        )

        inp = torch.cat([dp, wide_preproc], 1)
        return torch.sigmoid(
            torch.addmm(self.fc_b, inp, torch.t(self.fc_w), beta=1, alpha=1)
        )
class DeepAndWideCompiled(torch.nn.Module):
    """DeepAndWide model hand-lowered to NNC tensor expressions.

    The whole forward computation is described with ``torch._C.te`` tensor
    expressions and compiled once, in the constructor, with the ``'llvm'``
    codegen. The kernel is built for fixed ``num_features``, ``batch_size``
    and ``embedding_size``; ``forward`` only allocates the output buffer and
    invokes the compiled kernel.
    """

    def __init__(self, mu, sigma, fc_w, fc_b, num_features=50, batch_size=1, embedding_size=32):
        """Store the weights and compile the kernel for the given sizes.

        Args:
            mu, sigma: wide-feature shift / scale, declared to the kernel as
                float tensors of shape ``[1, num_features]``.
            fc_w: affine weights, declared as ``[1, num_features + 1]``.
            fc_b: affine bias, declared as ``[1]``.
            num_features: number of wide features.
            batch_size: batch size the kernel is specialized for.
            embedding_size: size of the ad / user embeddings.
        """
        super(DeepAndWideCompiled, self).__init__()
        self.mu = mu
        self.sigma = sigma
        self.fc_w = fc_w
        self.fc_b = fc_b
        # Remember the batch size the kernel is compiled for so that forward()
        # can size its output buffer without relying on a module-level global.
        self.batch_size = batch_size
        # NOTE(review): presumably the KernelScope must stay alive as long as
        # the TE objects / codegen are used, hence stored on self - confirm.
        self.scope = torch._C.te.KernelScope()
        self.compile_with_nnc(num_features, batch_size, embedding_size)

    def compile_with_nnc(self, num_features, batch_size, embedding_size):
        """Build the tensor-expression graph and compile it into ``self.codegen``.

        The compiled kernel takes its buffers in this order:
        ``AD_EMB, USER_EMB, WIDE, MU, SIGMA, FC_W, FC_B, X`` (``X`` is the
        output).
        """
        te = torch._C.te

        def get_dim_args(dims):
            # Wrap each extent in a DimArg named 'i0', 'i1', ... by position.
            dim_args = []
            for dim in dims:
                dim_args.append(te.DimArg(dim, 'i' + str(len(dim_args))))
            return dim_args

        # Constants and extents
        ZERO = te.ExprHandle.int(0)
        ONE = te.ExprHandle.int(1)
        FZERO = te.ExprHandle.float(0.)
        FTEN = te.ExprHandle.float(10.)
        BS = te.ExprHandle.int(batch_size)
        ES = te.ExprHandle.int(embedding_size)
        NF = te.ExprHandle.int(num_features)
        NF1 = te.ExprHandle.int(num_features + 1)
        dtype = te.Dtype.Float

        # Placeholders for the weights and the inputs
        MU = te.Placeholder('MU', dtype, [ONE, NF])
        SIGMA = te.Placeholder('SIGMA', dtype, [ONE, NF])
        FC_W = te.Placeholder('FC_W', dtype, [ONE, NF1])
        FC_B = te.Placeholder('FC_B', dtype, [ONE])
        AD_EMB = te.Placeholder('AD_EMB', dtype, [BS, ONE, ES])
        USER_EMB = te.Placeholder('USER_EMB', dtype, [BS, ONE, ES])
        WIDE = te.Placeholder('WIDE', dtype, [BS, NF])

        # Computation itself
        # Wide branch: (wide + mu) * sigma, clamped into [0, 10]. MU / SIGMA
        # are indexed with row 0 to broadcast over the batch dimension.
        wide_offset = te.Compute('wide_offset', get_dim_args([BS, NF]),
                                 lambda i, j: WIDE.load([i, j]) + MU.load([ZERO, j]))
        wide_norm = te.Compute('wide_norm', get_dim_args([BS, NF]),
                               lambda i, j: wide_offset.load([i, j]) * SIGMA.load([ZERO, j]))
        wide_preproc = te.Compute('wide_preproc', get_dim_args([BS, NF]),
                                  lambda i, j: te.ifThenElse(wide_norm.load([i, j]) < FZERO,
                                                             FZERO,
                                                             te.ifThenElse(wide_norm.load([i, j]) > FTEN,
                                                                           FTEN,
                                                                           wide_norm.load([i, j]))))

        # Deep branch: transpose the user embedding, multiply element-wise
        # with the ad embedding and sum over the embedding axis.
        user_emb_t = te.Compute('user_emb_t', get_dim_args([BS, ES, ONE]),
                                lambda i, j, k: USER_EMB.load([i, k, j]))
        dp_terms = te.Compute('dp_terms', get_dim_args([BS, ONE, ES]),
                              lambda i, _, j: AD_EMB.load([i, ZERO, j]) * user_emb_t.load([i, j, ZERO]))
        dp = te.SumReduce('dp', get_dim_args([BS, ONE]), dp_terms, get_dim_args([ES]))

        # Concatenation: column 0 comes from dp, the remaining columns from
        # wide_preproc shifted by one.
        inp = te.Compute('inp', get_dim_args([BS, NF1]),
                         lambda i, j: te.ifThenElse(j < ONE, dp.load([i, j]), wide_preproc.load([i, j - ONE])))

        # Affine layer: inp @ fc_w^T + fc_b, expressed as products + reduction.
        fc_w_t = te.Compute('fc_w_t', get_dim_args([NF1, ONE]),
                            lambda i, j: FC_W.load([j, i]))
        mm_terms = te.Compute('mm_terms', get_dim_args([BS, ONE, NF1]),
                              lambda i, j, k: inp.load([i, k]) * fc_w_t.load([k, j]))
        mm = te.SumReduce('mm', get_dim_args([BS, ONE]), mm_terms, get_dim_args([NF1]))
        fc1 = te.Compute('fc1', get_dim_args([BS, ONE]),
                         lambda i, j: mm.load([i, j]) + FC_B.load([j]))
        X = te.Compute('sigmoid', get_dim_args([BS, ONE]),
                       lambda i, j: te.ExprHandle.sigmoid(fc1.load([i, j])))

        # Generating TE statement
        loopnest = te.LoopNest([X])
        loopnest.prepare_for_codegen()
        stmt = te.simplify(loopnest.root_stmt())

        # Compiling
        kernel_buffers = [AD_EMB, USER_EMB, WIDE, MU, SIGMA, FC_W, FC_B, X]
        self.codegen = te.construct_codegen('llvm', stmt, [te.BufferArg(x) for x in kernel_buffers])

    def forward(self, ad_emb_packed, user_emb, wide):
        """Run the compiled kernel and return a ``[batch_size, 1]`` tensor.

        The inputs are passed to the kernel as-is; they are expected to match
        the sizes the kernel was compiled for.
        """
        # Output buffer sized from the batch size used at compile time (the
        # previous code read an undefined/global ``batch_size`` here).
        result = torch.empty(self.batch_size, 1)
        self.codegen.call([ad_emb_packed, user_emb, wide, self.mu, self.sigma, self.fc_w, self.fc_b, result])
        return result
if __name__ == "__main__":
    # Benchmark the same DeepAndWide model executed three ways: TorchScript,
    # Static Runtime, and a hand-written NNC kernel.
    num_features = 50

    # Fabricate sample inputs
    batch_size = 1
    embedding_size = 32
    ad_emb_packed = torch.randn(batch_size, 1, embedding_size)
    user_emb = torch.randn(batch_size, 1, embedding_size)
    wide = torch.randn(batch_size, num_features)
    inps = (ad_emb_packed, user_emb, wide)

    warmup = 10
    iters = 20000

    def _benchmark(report_fmt, fn, *args):
        """Warm up ``fn(*args)``, time ``iters`` calls, print and return the last result.

        ``report_fmt`` is a %-format string with a single float slot that
        receives the elapsed wall-clock seconds of the timed loop.
        """
        result = None
        for _ in range(warmup):
            result = fn(*args)
        start = timer()
        for _ in range(iters):
            result = fn(*args)
        end = timer()
        print(report_fmt % (end - start))
        print('Result:', result)
        return result

    # TorchScript JIT version
    m = torch.jit.script(DeepAndWide(num_features))
    m.eval()
    y = _benchmark('TorchScript: %.3fs', m, *inps)

    # Static Runtime version (takes the inputs as one tuple argument)
    static_runtime = torch._C._jit_to_static_runtime(m._c)
    y = _benchmark('Static runtime: %.3fs', static_runtime.run, inps)

    # NNC-compiled version, sharing the scripted model's weights
    nnc = DeepAndWideCompiled(m.mu, m.sigma, m.fc_w, m.fc_b, num_features, batch_size, embedding_size)
    y = _benchmark('Compiled with NNC: %.3fs', nnc, *inps)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment