# ==== dask-ps
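#
# A parameter-server pattern on dask.distributed: one task holds the model
# coefficients and publishes them through a Variable, while worker tasks pull
# the latest coefficients, compute gradients on their own data chunk, and push
# them back through a Queue. Written against the 2017-era dask/distributed API
# (LocalCluster.start_worker, worker_client(separate_thread=...)), so some of
# these calls may differ in newer releases.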
import dask
import dask.array as da
from dask import delayed
from dask_glm import families
from dask_glm.algorithms import lbfgs
from distributed import LocalCluster, Client, worker_client, Variable, Queue

import numpy as np
import time

from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
STEP_SIZE = 1.0
N = 100000
D = 10
ITER = 10
X_local, y_local = datasets.make_classification(n_classes=2, n_samples=N, n_features=D)
X = da.from_array(X_local, N // 8)
y = da.from_array(y_local, N // 8)

XD = X.to_delayed().flatten().tolist()  # a list of delayed chunks (numpy arrays once computed)
yD = y.to_delayed().flatten().tolist()

STEP_SIZE /= len(XD)  # need to adjust based on parallelism for convergence?

family = families.Logistic()
pointwise_gradient = family.pointwise_gradient
pointwise_loss = family.pointwise_loss
from contextlib import contextmanager

@contextmanager
def duration(s):
    # time a block: with duration('label'): ...
    start = time.time()
    yield
    end = time.time()
    print('DURATION', s, end - start)
def local_update(X, y, beta):
    # gradient of the logistic loss on this worker's chunk
    return pointwise_gradient(beta, X, y)
def parameter_server():
    beta = np.zeros(D)
    gti = np.zeros(D)  # running sum of squared gradients for AdaGrad scaling

    with worker_client() as c:
        betas = Variable('betas', client=c)
        stop = Variable('stop', client=c)
        updates = Queue('updates', client=c)

        # publish the initial coefficients as a scattered future
        [future_beta] = c.scatter([beta])
        betas.set(future_beta)

        while not stop.get():
            update = updates.get().result()
            print("PS: received update: %s" % update)

            # AdaGrad-style step: scale each coordinate by the inverse root
            # of its accumulated squared gradient
            gti += update ** 2
            adj_grad = update / (1e-6 + np.sqrt(gti))
            beta = beta - STEP_SIZE * adj_grad

            [future_beta] = c.scatter([beta])
            betas.set(future_beta)
            print("PS: pushed beta: %s" % beta)
def worker(X, y, i):
    with worker_client(separate_thread=False) as c:
        betas = Variable('betas', client=c)
        stop = Variable('stop', client=c)
        updates = Queue('updates', client=c)

        while not stop.get():
            # pull the latest coefficients, compute this chunk's gradient,
            # and push it back to the parameter server
            beta = betas.get().result()
            print("Worker %d: received beta: %s" % (i, beta))
            update = local_update(X, y, beta)  # .compute()
            [update_future] = c.scatter([update])
            updates.put(update_future)
            print("Worker %d: pushed update: %s" % (i, update))
if __name__ == '__main__':
    cluster = LocalCluster(n_workers=0)
    cluster.start_worker(1, name="ps")
    cluster.start_worker(4, name="w1")
    cluster.start_worker(4, name="w2")
    client = Client(cluster)

    betas = Variable('betas')
    stop = Variable('stop')
    updates = Queue('updates')
    stop.set(False)

    # start parameter server
    res_ps = client.submit(parameter_server, workers=['ps'], pure=False)

    # start workers computing, one per chunk
    XD = client.compute(XD)
    yD = client.compute(yD)
    res_workers = [client.submit(worker, xx, yy, i)
                   for i, (xx, yy) in enumerate(zip(XD, yD))]

    # Stop if beta converges or after ten seconds
    last = -1
    count_same = 0
    start = time.time()
    while count_same < 5:
        beta = betas.get().result()
        if np.allclose(last, beta, atol=1e-5):
            count_same += 1
        else:
            count_same = 0
        last = beta
        time.sleep(.2)
        print('updates queue size:', updates.qsize())
        if time.time() - start > 10:
            break

    stop.set(True)
    print("Converged to", beta)
    client.gather(res_workers)
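
    # Optional sanity check (a sketch; assumes X_local and y_local still fit
    # in local memory): the sklearn imports above are otherwise unused, so
    # compare the learned beta against sklearn's batch LogisticRegression.
    # AUC is rank-based, so the raw margin X_local.dot(beta) can stand in for
    # predicted probabilities.
    auc_ps = roc_auc_score(y_local, X_local.dot(beta))
    lr = LogisticRegression()
    lr.fit(X_local, y_local)
    auc_sk = roc_auc_score(y_local, lr.decision_function(X_local))
    print("AUC parameter server: %.4f, sklearn: %.4f" % (auc_ps, auc_sk))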