Skip to content

Instantly share code, notes, and snippets.

@TomAugspurger
Created November 6, 2017 12:19
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TomAugspurger/a794b3a153548966baf5b62e0806b0b4 to your computer and use it in GitHub Desktop.
Save TomAugspurger/a794b3a153548966baf5b62e0806b0b4 to your computer and use it in GitHub Desktop.
from urllib.parse import urlparse
import lightgbm as lgb
import threading
def _parse_machines(workers, listen_port):
"""From dask worker info to LightGBM mlist"""
# TODO: Assert that we're using TCP?
mlist = ['127.0.0.1:{}'.format(listen_port)] + [urlparse(worker).netloc
for worker in workers]
return mlist
def train(params, train):
    """Fit a LightGBM model and hand back the result of ``lgb.train``.

    Parameters
    ----------
    params : dict
        LightGBM parameters, forwarded unchanged.
    train : lightgbm.Dataset
        Training data, e.g. an ``lgb.Dataset``. The name shadows this
        function inside the body; it is kept so keyword callers still work.

    Returns
    -------
    Whatever ``lgb.train(params, train)`` returns (a LightGBM ``Booster``).
    """
    return lgb.train(params, train)
def main():
    """Demo: derive a LightGBM machine list from a local dask cluster, then train.

    Starts a dask ``Client`` with default arguments and prints the scheduler
    address. It then builds the LightGBM ``machines`` list from the workers
    and trains on a small synthetic classification problem. The client is
    used as a context manager so it is closed when the demo finishes or
    fails.
    """
    from sklearn.datasets import make_classification
    from distributed import Client

    # setup
    # NOTE(review): 12400 presumably mirrors LightGBM's default
    # local_listen_port -- confirm before changing it.
    port = 12400
    # Context manager closes the client (and the local cluster it started)
    # even if anything below raises.
    with Client() as c:
        # Fetch the scheduler snapshot once and reuse it for both lookups.
        info = c.scheduler_info()
        print('scheduler', info['address'])
        machines = _parse_machines(info['workers'], port)
        # NOTE(review): only "machines" is set; other distributed options
        # (e.g. num_machines, tree_learner) are not, so this presumably still
        # trains in-process. Confirm before relying on distributed training.
        params = {"machines": machines}

        # Data
        X, y = make_classification()
        dset = lgb.Dataset(X, label=y)

        # Train
        train(params, dset)
if __name__ == '__main__':
    # Run the demo only when this file is executed directly, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment