@ZolotukhinM
Created January 22, 2021 21:22
Three versions of the DeepAndWide model: TorchScript, Static Runtime, and NNC
import torch
from timeit import default_timer as timer


class DeepAndWide(torch.nn.Module):
    def __init__(self, num_features=50):
        super(DeepAndWide, self).__init__()
        self.mu = torch.randn(1, num_features)
        self.sigma = torch.randn(1, num_features)
        self.fc_w = torch.randn(1, num_features + 1)
        self.fc_b = torch.randn(1)

    def forward(self, ad_emb_packed, user_emb, wide):
        wide_offset = wide + self.mu
        wide_normalized = wide_offset * self.sigma
        wide_preproc = torch.clamp(wide_normalized, 0., 10.)
        user_emb_t = torch.transpose(user_emb, 1, 2)
        dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t)
        dp = torch.flatten(dp_unflatten, 1, -1)
        inp = torch.cat([dp, wide_preproc], 1)
        fc1 = torch.addmm(self.fc_b, inp, torch.t(self.fc_w), beta=1, alpha=1)
        return torch.sigmoid(fc1)


class DeepAndWideCompiled(torch.nn.Module):
    def __init__(self, mu, sigma, fc_w, fc_b, num_features=50, batch_size=1, embedding_size=32):
        super(DeepAndWideCompiled, self).__init__()
        self.mu = mu
        self.sigma = sigma
        self.fc_w = fc_w
        self.fc_b = fc_b
        self.batch_size = batch_size
        self.scope = torch._C.te.KernelScope()
        self.compile_with_nnc(num_features, batch_size, embedding_size)

    def compile_with_nnc(self, num_features, batch_size, embedding_size):
        def get_dim_args(dims):
            dim_args = []
            for dim in dims:
                dim_args.append(torch._C.te.DimArg(dim, 'i' + str(len(dim_args))))
            return dim_args

        # Placeholders
        ZERO = torch._C.te.ExprHandle.int(0)
        ONE = torch._C.te.ExprHandle.int(1)
        FZERO = torch._C.te.ExprHandle.float(0.)
        FTEN = torch._C.te.ExprHandle.float(10.)
        BS = torch._C.te.ExprHandle.int(batch_size)
        ES = torch._C.te.ExprHandle.int(embedding_size)
        NF = torch._C.te.ExprHandle.int(num_features)
        NF1 = torch._C.te.ExprHandle.int(num_features + 1)
        dtype = torch._C.te.Dtype.Float
        MU = torch._C.te.Placeholder('MU', dtype, [ONE, NF])
        SIGMA = torch._C.te.Placeholder('SIGMA', dtype, [ONE, NF])
        FC_W = torch._C.te.Placeholder('FC_W', dtype, [ONE, NF1])
        FC_B = torch._C.te.Placeholder('FC_B', dtype, [ONE])
        AD_EMB = torch._C.te.Placeholder('AD_EMB', dtype, [BS, ONE, ES])
        USER_EMB = torch._C.te.Placeholder('USER_EMB', dtype, [BS, ONE, ES])
        WIDE = torch._C.te.Placeholder('WIDE', dtype, [BS, NF])

        # Computation itself
        wide_offset = torch._C.te.Compute('wide_offset', get_dim_args([BS, NF]),
                                          lambda i, j: WIDE.load([i, j]) + MU.load([ZERO, j]))
        wide_norm = torch._C.te.Compute('wide_norm', get_dim_args([BS, NF]),
                                        lambda i, j: wide_offset.load([i, j]) * SIGMA.load([ZERO, j]))
        # clamp(x, 0, 10) expressed as nested ifThenElse
        wide_preproc = torch._C.te.Compute('wide_preproc', get_dim_args([BS, NF]),
                                           lambda i, j: torch._C.te.ifThenElse(
                                               wide_norm.load([i, j]) < FZERO,
                                               FZERO,
                                               torch._C.te.ifThenElse(wide_norm.load([i, j]) > FTEN,
                                                                      FTEN,
                                                                      wide_norm.load([i, j]))))
        user_emb_t = torch._C.te.Compute('user_emb_t', get_dim_args([BS, ES, ONE]),
                                         lambda i, j, k: USER_EMB.load([i, k, j]))
        # bmm(ad_emb_packed, user_emb_t): elementwise products, then a sum reduction over ES
        dp_terms = torch._C.te.Compute('dp_terms', get_dim_args([BS, ONE, ES]),
                                       lambda i, _, j: AD_EMB.load([i, ZERO, j]) * user_emb_t.load([i, j, ZERO]))
        dp = torch._C.te.SumReduce('dp', get_dim_args([BS, ONE]), dp_terms, get_dim_args([ES]))
        # cat([dp, wide_preproc], dim=1)
        inp = torch._C.te.Compute('inp', get_dim_args([BS, NF1]),
                                  lambda i, j: torch._C.te.ifThenElse(j < ONE,
                                                                      dp.load([i, j]),
                                                                      wide_preproc.load([i, j - ONE])))
        # addmm(fc_b, inp, fc_w.t()): transpose, multiply, reduce over NF1, add bias
        fc_w_t = torch._C.te.Compute('fc_w_t', get_dim_args([NF1, ONE]),
                                     lambda i, j: FC_W.load([j, i]))
        mm_terms = torch._C.te.Compute('mm_terms', get_dim_args([BS, ONE, NF1]),
                                       lambda i, j, k: inp.load([i, k]) * fc_w_t.load([k, j]))
        mm = torch._C.te.SumReduce('mm', get_dim_args([BS, ONE]), mm_terms, get_dim_args([NF1]))
        fc1 = torch._C.te.Compute('fc1', get_dim_args([BS, ONE]),
                                  lambda i, j: mm.load([i, j]) + FC_B.load([j]))
        X = torch._C.te.Compute('sigmoid', get_dim_args([BS, ONE]),
                                lambda i, j: torch._C.te.ExprHandle.sigmoid(fc1.load([i, j])))

        # Generating TE statement
        loopnest = torch._C.te.LoopNest([X])
        loopnest.prepare_for_codegen()
        stmt = torch._C.te.simplify(loopnest.root_stmt())

        # Compiling
        self.codegen = torch._C.te.construct_codegen(
            'llvm', stmt,
            [torch._C.te.BufferArg(x) for x in [AD_EMB, USER_EMB, WIDE, MU, SIGMA, FC_W, FC_B, X]])

    def forward(self, ad_emb_packed, user_emb, wide):
        # Allocate the output buffer; the generated kernel writes into it in place.
        result = torch.empty(self.batch_size, 1)
        self.codegen.call([ad_emb_packed, user_emb, wide, self.mu, self.sigma, self.fc_w, self.fc_b, result])
        return result


if __name__ == "__main__":
    num_features = 50

    # Fabricate sample inputs
    batch_size = 1
    embedding_size = 32
    ad_emb_packed = torch.randn(batch_size, 1, embedding_size)
    user_emb = torch.randn(batch_size, 1, embedding_size)
    wide = torch.randn(batch_size, num_features)
    inps = (ad_emb_packed, user_emb, wide)
    warmup = 10
    iters = 20000

    # TorchScript JIT version
    m = torch.jit.script(DeepAndWide(num_features))
    m.eval()
    for _ in range(warmup):
        y = m(*inps)
    start = timer()
    for _ in range(iters):
        y = m(*inps)
    end = timer()
    print('TorchScript: %.3fs' % (end - start))
    print('Result:', y)

    # Static Runtime version
    static_runtime = torch._C._jit_to_static_runtime(m._c)
    for _ in range(warmup):
        y = static_runtime.run(inps)
    start = timer()
    for _ in range(iters):
        y = static_runtime.run(inps)
    end = timer()
    print('Static runtime: %.3fs' % (end - start))
    print('Result:', y)

    # NNC-compiled version
    nnc = DeepAndWideCompiled(m.mu, m.sigma, m.fc_w, m.fc_b, num_features, batch_size, embedding_size)
    for _ in range(warmup):
        y = nnc(*inps)
    start = timer()
    for _ in range(iters):
        y = nnc(*inps)
    end = timer()
    print('Compiled with NNC: %.3fs' % (end - start))
    print('Result:', y)
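
A quick way to sanity-check that the three paths compute the same thing is to compare their outputs directly. The lines below are a minimal sketch, not part of the original gist: they assume the script above has already run in the same session (so m, static_runtime, nnc, and inps are in scope), and they assume static_runtime.run() may wrap its result in a list; the atol value is an arbitrary choice.

# Minimal verification sketch (assumptions noted above).
ref = m(*inps)
sr_out = static_runtime.run(inps)
# Unwrap the Static Runtime output if it comes back in a container (assumption).
sr_out = sr_out[0] if isinstance(sr_out, (list, tuple)) else sr_out
nnc_out = nnc(*inps)
print('Static Runtime matches TorchScript:', torch.allclose(ref, sr_out, atol=1e-5))
print('NNC matches TorchScript:', torch.allclose(ref, nnc_out, atol=1e-5))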