Skip to content

Instantly share code, notes, and snippets.

@DIYer22
Created March 14, 2019 07:42
Show Gist options
  • Save DIYer22/0bb621cea2d817b3ce85a610b73a5a55 to your computer and use it in GitHub Desktop.
Save DIYer22/0bb621cea2d817b3ce85a610b73a5a55 to your computer and use it in GitHub Desktop.
赵振宇的 numpy 加速求每一对向量间的 L2 (Zhao Zhenyu's NumPy-accelerated computation of the L2 distance between every pair of vectors)
from boxx import *
import numpy as np
def zzy(M):
    """Return the matrix of pairwise Euclidean (L2) distances between rows of M.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, so every pair
    is obtained from one matrix product instead of a Python loop.

    Args:
        M: 2-D array with one vector per row, e.g.
            ``np.array([[1, 1], [2, 2], [3, 3]])``.

    Returns:
        An (n, n) array whose entry [i, j] is the L2 distance between
        ``M[i]`` and ``M[j]``; the diagonal is exactly 0.
    """
    # Squared norm of each row as a (1, n) row vector; its transpose is the
    # matching (n, 1) column vector, so broadcasting produces an (n, n) grid.
    a2 = np.sum(M * M, axis=1).reshape(1, -1)
    b2 = a2.T
    ab = M.dot(M.T)
    res = a2 - 2 * ab + b2
    # Floating-point cancellation can leave squared distances of (near-)equal
    # rows slightly negative, which would become NaN under the square root.
    res = np.maximum(res, 0)
    # A vector's distance to itself is exactly zero; remove rounding residue.
    np.fill_diagonal(res, 0)
    return res ** .5
# Benchmark input matrix, scaled down by 15.
# NOTE(review): `randomm` comes from boxx; the first argument looks like the
# output shape (1000 vectors x 128 dims) and the second an upper bound on the
# random values — confirm against the boxx documentation.
ma = randomm((1000, 128), 10) / 15
# h: number of vectors (rows); w: dimension of each vector (columns).
h, w = ma.shape
def old(ma):
    """Reference pairwise L2 distances computed with an explicit double loop.

    For each pair of rows a, b the distance is ``((a - b)**2).sum()**.5``.
    Slow but straightforward, which makes it a convenient ground truth for
    vectorised implementations.

    Args:
        ma: 2-D array with one vector per row.

    Returns:
        An (n, n) float array, symmetric with a zero diagonal, where
        n = ``ma.shape[0]``.
    """
    # Size from the argument itself (not a module global) so any input works.
    n = ma.shape[0]
    l2 = np.zeros((n, n))
    # Distances are symmetric: compute the upper triangle and mirror it.
    # The diagonal (a vector's distance to itself) stays 0.
    for i in range(n - 1):
        for j in range(i + 1, n):
            l2[j, i] = l2[i, j] = ((ma[i] - ma[j]) ** 2).sum() ** .5
    return l2
def mynew(ma):
    """Pairwise L2 distances via a per-dimension expansion of (a - b)**2.

    For each dimension k, the squared difference between rows i and j is
    expanded as a_k**2 + b_k**2 - 2*a_k*b_k and evaluated for all (i, j) at
    once with broadcasting. Summing over k gives the squared distance.

    Note: intermediates have shape (w, n, n) for an (n, w) input, so memory
    grows as w * n**2.

    Args:
        ma: 2-D array with one vector per row.

    Returns:
        An (n, n) array whose entry [i, j] is the L2 distance between
        ``ma[i]`` and ``ma[j]``.
    """
    cols = ma.T  # (w, n): one row per dimension
    # (w, n, 1) @ (w, 1, n) -> (w, n, n): per-dimension outer products a_k*b_k.
    ab = np.matmul(cols[..., None], cols[..., None, :])
    poww = ma ** 2
    a_b2 = poww.T[..., None] + poww.T[..., None, :] - 2 * ab
    sq = a_b2.sum(0)
    # Cancellation in a**2 + b**2 - 2ab can leave tiny negative sums for
    # near-identical rows; clamp so the square root never produces NaN.
    l2 = np.maximum(sq, 0) ** .5
    return l2
# Time the three implementations on the same input matrix. `timeit` is used
# as a labelled context manager (presumably boxx's timer — confirm).
with timeit('old'):
    l21 = old(ma)
with timeit('new'):
    l22 = mynew(ma)
with timeit('zzy'):
    l23 = zzy(ma)

# Agreement check: total absolute element-wise deviation of each vectorised
# result from the loop-based result; values near 0 mean they agree.
print('diff:', np.abs(l21 - l22).sum())
print('diff2:', np.abs(l21 - l23).sum())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment