# Gist by @mrocklin, created June 5, 2017
# https://gist.github.com/mrocklin/a92785743744b5c698984e16b7065037
# ==== dask-ps
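#
# A parameter-server pattern on Dask: one worker (named "ps") keeps the model
# coefficients in a distributed Variable and folds gradient updates from a
# Queue into them with an Adagrad-style rule; the remaining workers repeatedly
# pull the latest coefficients and push gradients computed on their own chunk.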
import dask
import dask.array as da
from contextlib import contextmanager
from dask import delayed
from dask_glm import families
from dask_glm.algorithms import lbfgs
from distributed import LocalCluster, Client, worker_client, Variable, Queue
import numpy as np
import time
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

STEP_SIZE = 1.0
N = 100000
D = 10
ITER = 10
X_local, y_local = datasets.make_classification(n_classes=2, n_samples=N, n_features=D)
X = da.from_array(X_local, N // 8)
y = da.from_array(y_local, N // 8)
XD = X.to_delayed().flatten().tolist()  # a list of delayed numpy arrays, one per chunk
yD = y.to_delayed().flatten().tolist()
STEP_SIZE /= len(XD)  # need to adjust based on parallelism for convergence?

family = families.Logistic()
pointwise_gradient = family.pointwise_gradient
pointwise_loss = family.pointwise_loss
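
# For orientation (an assumption about dask_glm's internals, not something the
# gist itself states): for the logistic family with labels in {0, 1}, the
# pointwise gradient is the familiar log-loss gradient. A hypothetical NumPy
# equivalent, for reference only:
def reference_gradient(beta, X, y):
    p = 1.0 / (1.0 + np.exp(-X.dot(beta)))   # predicted probabilities
    return X.T.dot(p - y)                    # gradient of the log-loss in beta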

@contextmanager
def duration(s):
    # small helper: time a block and print the elapsed seconds
    start = time.time()
    yield
    end = time.time()
    print('DURATION', s, end - start)

def local_update(X, y, beta):
    # gradient of the loss on this worker's chunk of the data
    return pointwise_gradient(beta, X, y)

def parameter_server():
    beta = np.zeros(D)
    gti = np.zeros(D)   # running sum of squared gradients (Adagrad accumulator)
    with worker_client() as c:
        betas = Variable('betas', client=c)
        stop = Variable('stop', client=c)
        updates = Queue('updates', client=c)
        # publish the initial coefficients
        [future_beta] = c.scatter([beta])
        betas.set(future_beta)
        while not stop.get():
            update = updates.get().result()
            print("PS: received update: %s" % update)
            # Adagrad-style step: scale each coordinate by the root of its
            # accumulated squared gradient
            gti += update ** 2
            adj_grad = update / (1e-6 + np.sqrt(gti))
            beta = beta - STEP_SIZE * adj_grad
            # publish the new coefficients for the workers to pick up
            [future_beta] = c.scatter([beta])
            betas.set(future_beta)
            print("PS: pushed beta: %s" % beta)

def worker(X, y, i):
    with worker_client(separate_thread=False) as c:
        betas = Variable('betas', client=c)
        stop = Variable('stop', client=c)
        updates = Queue('updates', client=c)
        while not stop.get():
            # pull the latest coefficients from the parameter server
            beta = betas.get().result()
            print("Worker %d: received beta: %s" % (i, beta))
            update = local_update(X, y, beta)  # .compute()
            # push the gradient back as a future so the payload stays on
            # the workers rather than round-tripping through the scheduler
            [update_future] = c.scatter([update])
            updates.put(update_future)
            print("Worker %d: pushed update: %s" % (i, update))

if __name__ == '__main__':
    # a dedicated worker named "ps" for the parameter server plus two more
    # workers for gradient computation; submissions are pinned by worker name
    cluster = LocalCluster(n_workers=0)
    cluster.start_worker(1, name="ps")
    cluster.start_worker(4, name="w1")
    cluster.start_worker(4, name="w2")
    client = Client(cluster)

    betas = Variable('betas')
    stop = Variable('stop')
    updates = Queue('updates')
    stop.set(False)

    # start parameter server
    res_ps = client.submit(parameter_server, workers=['ps'], pure=False)

    # start workers computing, one task per data chunk
    XD = client.compute(XD)
    yD = client.compute(yD)
    res_workers = [client.submit(worker, xx, yy, i)
                   for i, (xx, yy) in enumerate(zip(XD, yD))]

    # Stop if beta converges or after ten seconds
    last = -1
    count_same = 0
    start = time.time()
    while count_same < 5:
        beta = betas.get().result()
        if np.allclose(last, beta, atol=1e-5):
            count_same += 1
        else:
            count_same = 0
        last = beta
        time.sleep(.2)
        print('updates queue size:', updates.qsize())
        if time.time() - start > 10:
            break

    stop.set(True)
    print("Converged to", beta)
    client.gather(res_workers)
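
    # Not part of the original gist: the sklearn imports above go unused, so
    # as a hedged sketch, compare the converged beta against an sklearn
    # baseline fit on the same in-memory data:
    probs = 1.0 / (1.0 + np.exp(-X_local.dot(beta)))
    print("dask-ps AUC:", roc_auc_score(y_local, probs))
    baseline = LogisticRegression().fit(X_local, y_local)
    print("sklearn AUC:", roc_auc_score(y_local, baseline.predict_proba(X_local)[:, 1]))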