Skip to content

Instantly share code, notes, and snippets.

@botev
Created September 21, 2017 02:42
Show Gist options
  • Save botev/c3196335d0afcbb46220ca3f021ca448 to your computer and use it in GitHub Desktop.
Save botev/c3196335d0afcbb46220ca3f021ca448 to your computer and use it in GitHub Desktop.
Scipy LBFGS
import os
import time
import numpy as np
import theano
import theano.tensor as T
from scipy.optimize import fmin_l_bfgs_b
def scipy(path="."):
    """Fit a one-hidden-layer tanh network with a logistic loss via L-BFGS-B.

    Reads ``x.csv`` (inputs, one row per sample), ``y.csv`` (labels) and
    ``w.csv`` (parameter vectors, one per column) from *path*. It builds the
    mean negative log-likelihood and its gradient with Theano, then minimises
    it with ``scipy.optimize.fmin_l_bfgs_b``, starting from the first column
    of ``w.csv``. The objective is also evaluated at the last column of
    ``w.csv`` and printed as the reference "True minimum".

    Parameters
    ----------
    path : str, optional
        Directory containing ``x.csv``, ``y.csv`` and ``w.csv``. Defaults to
        the current directory, which matches the previous hard-coded value.

    Side effects
    ------------
    Sets ``theano.config.floatX`` to ``"float64"``. Prints the elapsed time,
    the solution vector, the final objective, the reference objective and
    the optimizer's stopping message to stdout.
    """
    theano.config.floatX = "float64"

    # Load with numpy first so the dimension comes straight from the array
    # instead of compiling a Theano graph to evaluate ``x.shape``.
    # ``ndmin`` keeps the ranks the code relies on even for degenerate files:
    # a single feature column in x, a single sample in y, or a single
    # parameter column in w.
    x_data = np.loadtxt(os.path.join(path, "x.csv"), delimiter=",", ndmin=2)
    y_data = np.loadtxt(os.path.join(path, "y.csv"), delimiter=",", ndmin=1)
    w_all = np.loadtxt(os.path.join(path, "w.csv"), delimiter=",", ndmin=2)
    d = x_data.shape[1]

    x = theano.shared(x_data, name="x")
    y = theano.shared(y_data, name="y")

    # Flat parameter vector: the first d*d entries form W1 and the last d
    # entries form W2.
    w = T.dvector()
    # Reshape row-major and then transpose, so W1 is filled column by column.
    # This presumably matches a column-major export of w.csv (TODO confirm).
    W1 = T.reshape(w[:d * d], (d, d)).T
    W2 = T.reshape(w[-d:], (d, 1))
    h = T.tanh(T.dot(x, W1))
    pred = T.dot(h, W2).flatten()

    # Per-sample log-likelihood, written with softplus for numerical
    # stability:
    #   y == 1 -> log sigmoid(pred)  = -softplus(-pred)
    #   y == 0 -> log sigmoid(-pred) = -softplus(pred)
    # Samples labelled neither 0 nor 1 contribute 0 but still count in the
    # mean.
    i0 = T.eq(y, 0).nonzero()
    i1 = T.eq(y, 1).nonzero()
    loglik = T.zeros_like(y)
    loglik = T.set_subtensor(loglik[i1], -T.nnet.softplus(-pred[i1]))
    loglik = T.set_subtensor(loglik[i0], -T.nnet.softplus(pred[i0]))
    loss = -T.mean(loglik)
    grad = T.grad(loss, w)
    # fmin_l_bfgs_b expects a callable returning (objective, gradient).
    func = theano.function([w], [loss, grad])

    start = time.perf_counter()
    # Distinct result names, so the symbolic x / loss / dimension d are not
    # rebound.
    w_opt, f_opt, info = fmin_l_bfgs_b(
        func, w_all[:, 0], m=100, pgtol=1e-9, maxls=25, factr=10)
    print("Time:", time.perf_counter() - start)
    print("Final solution:")
    print(w_opt)
    print("Final optimum:")
    print(f_opt)
    print("True minimum")
    print(func(w_all[:, -1])[0])
    print("Stopping criteria:")
    task = info["task"]
    # Some SciPy releases report the stopping message as bytes; decode it so
    # it prints as plain text.
    print(task.decode() if isinstance(task, bytes) else task)
if __name__ == "__main__":
    # Script entry point: run the L-BFGS-B fit only when executed directly,
    # not when this module is imported.
    scipy()
-0.10224 0.053878 0.73074 0.40254 0.4744 0.69503 0.79161 1.5217 1.8335 1.9425 2.0424 2.1831 2.9329 2.9685 3.1143 3.9597 4.174 4.5484 4.8624 5.2055 5.5404
-0.24145 -0.49594 -1.0375 -1.2273 -1.2093 -1.2441 -1.3338 -2.0521 -2.4791 -2.7628 -3.1358 -3.5747 -5.8406 -5.7088 -5.6764 -7.5494 -7.9538 -8.6604 -9.2517 -9.8955 -10.522
0.31921 0.29676 0.23828 0.36477 0.4113 0.70706 1.0074 2.976 3.6505 3.7452 3.6937 3.6914 3.8137 4.0641 4.5474 5.3862 5.6607 6.1288 6.5108 6.9186 7.3092
0.31286 0.27309 0.23822 0.22449 0.22564 0.21579 0.188 -0.24867 -0.41676 -0.4008 -0.34268 -0.22448 0.63744 0.47189 0.22938 0.66287 0.71134 0.79897 0.87432 0.95775 1.0398
-0.86488 -0.80435 -0.61785 -0.61191 -0.60664 -0.57059 -0.54128 -0.38332 -0.35968 -0.43888 -0.51436 -0.4803 -0.12703 -0.17447 -0.24251 -0.039879 -0.013067 0.023439 0.042959 0.053249 0.054513
-0.030051 -0.029149 -0.048837 -0.04844 -0.049474 -0.057194 -0.068394 -0.38417 -0.86068 -1.0125 -1.1601 -1.3153 -2.1248 -2.0558 -1.991 -2.5537 -2.6646 -2.8546 -3.0096 -3.1742 -3.3311
-0.16488 -0.16567 -0.16287 -0.11368 -0.11034 -0.066712 -0.013944 0.40123 0.66539 0.81394 0.98271 1.2547 2.6985 2.6081 2.5844 3.7324 3.9803 4.4135 4.776 5.1704 5.5536
0.62771 0.60568 0.69782 0.74362 0.75637 0.84433 0.93202 1.5442 1.849 1.9755 2.0795 2.2006 2.8546 2.9249 3.126 4.0132 4.2422 4.6379 4.9657 5.3203 5.6636
1.0933 1.1033 1.0581 1.0295 1.0179 0.94806 0.87999 0.43665 0.43701 0.56782 0.71942 0.83867 1.3756 1.2282 0.98593 1.0532 1.0285 0.993 0.96974 0.9501 0.93564
1.1093 1.1779 1.8146 1.8318 1.86 2.0638 2.2456 3.5007 3.9823 4.2019 4.5125 4.966 7.3155 7.1842 7.1818 9.1009 9.5323 10.304 10.967 11.709 12.445
-0.86365 -0.64692 0.062389 0.12813 0.15571 0.35623 0.54966 1.9258 2.6389 2.9305 3.0988 3.2469 4.1413 4.4848 5.2164 7.2504 7.8122 8.7634 9.5329 10.348 11.124
0.077359 -0.04155 -0.33371 -0.34165 -0.34671 -0.40648 -0.47491 -1.2027 -1.5338 -1.6549 -1.7012 -1.783 -2.3715 -2.5664 -3.0263 -4.3318 -4.7066 -5.3785 -5.9628 -6.6238 -7.288
0.53767 -1.3499 0.6715
1.8339 3.0349 -1.2075
-2.2588 0.7254 0.71724
0.86217 -0.063055 1.6302
0.31877 0.71474 0.48889
-1.3077 -0.20497 1.0347
-0.43359 -0.12414 0.72689
0.34262 1.4897 -0.30344
3.5784 1.409 0.29387
2.7694 1.4172 -0.78728
We can make this file beautiful and searchable if this error is corrected: No commas found in this CSV file in line 0.
1
0
0
0
0
1
1
0
1
0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment