@qianyizhang · Created July 11, 2019
Replacement of scaled_l2 and aggregate in PyTorch-Encoding/encoding/functions/encoding.py with pure torch ops.
import torch


def scaled_l2(X, C, S):
    """
    scaled_l2 distance
    Args:
        X (b*n*d): original feature input
        C (k*d): code words, with k codes, each of dimension d
        S (k): scale coefficient, one per code
    Return:
        D (b*n*k): relative distance to each code
    Note:
        the expanded X^2 + C^2 - 2XC computation is roughly 2x faster than
        the elementwise broadcast version below, likely because the matmul
        is more cache-friendly on GPU
    """
    assert X.shape[-1] == C.shape[-1], "input, codeword feature dim mismatch"
    assert S.numel() == C.shape[0], "scale, codeword num mismatch"
    # simpler but slower broadcast version:
    # X = X.unsqueeze(2)                         # [b, n, 1, d]
    # C = C[None, None, ...]                     # [1, 1, k, d]
    # norm = torch.norm(X - C, dim=-1).pow(2.0)  # [b, n, k]
    # scaled_norm = S * norm
    b, n, d = X.shape
    X = X.view(-1, d)                      # [bn, d]
    Ct = C.t()                             # [d, k]
    X2 = X.pow(2.0).sum(-1, keepdim=True)  # [bn, 1]
    C2 = Ct.pow(2.0).sum(0, keepdim=True)  # [1, k]
    norm = X2 + C2 - 2.0 * X.mm(Ct)        # [bn, k]
    scaled_norm = S * norm                 # broadcast scale over the k codes
    D = scaled_norm.view(b, n, -1)         # [b, n, k]
    return D
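

# A quick correctness sketch (not part of the original gist): cross-check
# the expanded X^2 + C^2 - 2XC form against the naive broadcast formula on
# random tensors. The sizes b, n, d, k below are arbitrary example values.
def _check_scaled_l2(b=2, n=4, d=8, k=3):
    X = torch.randn(b, n, d)
    C = torch.randn(k, d)
    S = torch.randn(k)
    D = scaled_l2(X, C, S)                                       # [b, n, k]
    ref = S * (X.unsqueeze(2) - C[None, None]).pow(2.0).sum(-1)  # [b, n, k]
    assert D.shape == (b, n, k)
    assert torch.allclose(D, ref, atol=1e-4)
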
def aggregate(A, X, C):
    """
    aggregate residuals from N samples
    Args:
        A (b*n*k): weight of each feature's contribution to each code residual
        X (b*n*d): original feature input
        C (k*d): code words, with k codes, each of dimension d
    Return:
        E (b*k*d): residuals to each code
    """
    assert X.shape[-1] == C.shape[-1], "input, codeword feature dim mismatch"
    assert A.shape[:2] == X.shape[:2], "weight, input dim mismatch"
    X = X.unsqueeze(2)      # [b, n, d] -> [b, n, 1, d]
    C = C[None, None, ...]  # [k, d]    -> [1, 1, k, d]
    A = A.unsqueeze(-1)     # [b, n, k] -> [b, n, k, 1]
    R = (X - C) * A         # weighted residuals, [b, n, k, d]
    E = R.sum(dim=1)        # sum over the n samples, [b, k, d]
    return E
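

# A minimal usage sketch (not part of the original gist) showing how the two
# ops fit together, as in PyTorch-Encoding's Encoding layer: softmax the
# scaled distances over the k codes to get soft assignments, then aggregate
# the weighted residuals. Tensor sizes are arbitrary example values.
if __name__ == "__main__":
    import torch.nn.functional as F

    _check_scaled_l2()

    b, n, d, k = 2, 4, 8, 3
    X = torch.randn(b, n, d)  # input features
    C = torch.randn(k, d)     # codewords
    S = torch.randn(k)        # per-code scale coefficients
    A = F.softmax(scaled_l2(X, C, S), dim=2)  # soft assignments, [b, n, k]
    E = aggregate(A, X, C)                    # residual encodings, [b, k, d]
    assert E.shape == (b, k, d)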