Skip to content

Instantly share code, notes, and snippets.

@r9y9
Created June 27, 2017 11:53
Show Gist options
  • Save r9y9/d3d958458343b27ff5adb6dd160d3a4e to your computer and use it in GitHub Desktop.
Save r9y9/d3d958458343b27ff5adb6dd160d3a4e to your computer and use it in GitHub Desktop.
Faster version of `__construct_weight_matrix` from https://gist.github.com/r9y9/88bda659c97f46f42525#file-gmmmap-py: it builds the same sparse weight matrix W without re-stacking the matrix once per frame.
import scipy.sparse
import numpy as np
# Reference implementation from my previous code (slow: re-stacks the whole matrix once per frame)
def __construct_weight_matrix(T, D):
# Construct Weight matrix W
# Eq.(25) ~ (28)
tmp = np.zeros(D)
for t in range(T):
w0 = scipy.sparse.lil_matrix((D, D * T))
w1 = scipy.sparse.lil_matrix((D, D * T))
w0[0:, t * D:(t + 1) * D] = scipy.sparse.diags(np.ones(D), 0)
if t - 1 >= 0:
tmp.fill(-0.5)
w1[0:, (t - 1) * D:t * D] = scipy.sparse.diags(tmp, 0)
if t + 1 < T:
tmp.fill(0.5)
w1[0:, (t + 1) * D:(t + 2) * D] = scipy.sparse.diags(tmp, 0)
W_t = scipy.sparse.vstack([w0, w1])
# Slower
# self.W[2*D*t:2*D*(t+1),:] = W_t
if t == 0:
W = W_t
else:
W = scipy.sparse.vstack([W, W_t])
W = scipy.sparse.csr_matrix(W)
return W
# Faster implementation producing the same CSR matrix without per-frame vstack calls
def faster_w(T, D):
    """Build the weight matrix W of Eq. (25)-(28) with vectorized NumPy.

    W maps the stacked static sequence of ``T`` frames of ``D`` dimensions
    each (length ``D * T``) to the interleaved static/delta sequence
    (length ``2 * D * T``). For frame ``t`` and dimension ``i``:

    * row ``2*D*t + i`` has 1 at column ``D*t + i`` (static feature);
    * row ``2*D*t + D + i`` has -0.5 at column ``D*(t-1) + i`` when
      ``t > 0`` and 0.5 at column ``D*(t+1) + i`` when ``t < T - 1``
      (delta feature). A neighbour that does not exist is omitted.

    Previous-frame and next-frame terms are handled independently. So for
    ``T == 1`` the delta rows are simply empty; the former branchy version
    emitted out-of-range column indices in that case.

    Args:
        T: number of frames.
        D: feature dimension per frame.

    Returns:
        scipy.sparse.csr_matrix of shape ``(2 * D * T, D * T)`` in canonical
        form (column indices sorted within each row, no duplicates).
    """
    # One slot per (frame, dimension) pair; cols[k] == D*t + i.
    cols = np.arange(D * T)
    frames = cols // D
    static_rows = cols + D * frames  # == 2*D*t + i
    delta_rows = static_rows + D     # == 2*D*t + D + i

    has_prev = frames > 0
    has_next = frames < T - 1

    rows = np.concatenate([
        static_rows,
        delta_rows[has_prev],
        delta_rows[has_next],
    ])
    columns = np.concatenate([
        cols,
        cols[has_prev] - D,
        cols[has_next] + D,
    ])
    values = np.concatenate([
        np.ones(D * T),
        np.full(np.count_nonzero(has_prev), -0.5),
        np.full(np.count_nonzero(has_next), 0.5),
    ])

    # COO -> CSR conversion sorts column indices within each row, giving the
    # same data/indices/indptr layout as the reference construction.
    return scipy.sparse.coo_matrix(
        (values, (rows, columns)), shape=(2 * D * T, D * T)).tocsr()
if __name__ == "__main__":
    import time

    # Benchmark both constructions, then check that the fast version
    # reproduces the reference CSR layout exactly.
    T, D = 1000, 25

    # perf_counter is monotonic and high-resolution, unlike time.time().
    since = time.perf_counter()
    W = __construct_weight_matrix(T, D)
    print("Elapsed time: {}".format(time.perf_counter() - since))

    since = time.perf_counter()
    W_new = faster_w(T, D)
    print("Elapsed time: {}".format(time.perf_counter() - since))

    # np.array_equal returns False on a length mismatch instead of failing
    # inside an elementwise ``==``. Explicit raising (rather than ``assert``)
    # keeps the check active under ``python -O``.
    checks = {
        "shape": W.shape == W_new.shape,
        "type": type(W) is type(W_new),
        "data": np.array_equal(W.data, W_new.data),
        "indices": np.array_equal(W.indices, W_new.indices),
        "indptr": np.array_equal(W.indptr, W_new.indptr),
    }
    failed = [name for name, ok in checks.items() if not ok]
    if failed:
        raise AssertionError(
            "faster_w differs from reference in: " + ", ".join(failed))
@r9y9
Copy link
Author

r9y9 commented Jun 27, 2017

$ python a.py
Elapsed time: 1.34931612015
Elapsed time: 0.0376768112183

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment