"""Export schema.table from Databricks to S3 as hash-partitioned, gzipped CSV pages."""
import argparse
import base64
import csv
import gzip
import logging
import multiprocessing
import tempfile

import boto3
from pyhive import hive
from thrift.transport import THttpClient

logger = logging.getLogger("databricks_test")


class DatabricksClient:
    """Thin wrapper around a pyhive/Thrift-over-HTTP connection to a Databricks cluster."""

    def __init__(self):
        # Redacted workspace details; fill these in for your own cluster.
        self.cluster_id = "xxx"
        self.token = "xxx"
        self.workspace_id = "xxx"
        self.host = "xxx"

    def get_full_url(self):
        # Cluster-scoped Thrift HTTP endpoint for the workspace.
        return f"https://{self.host}/sql/protocolv1/o/{self.workspace_id}/{self.cluster_id}"

    def get_auth_headers(self):
        # Databricks accepts basic auth with "token" as the username and the personal access token as the password.
        auth = base64.standard_b64encode(f"token:{self.token}".encode()).decode()
        return {"Authorization": f"Basic {auth}"}

    def get_cursor(self):
        transport = THttpClient.THttpClient(self.get_full_url())
        transport.setCustomHeaders(self.get_auth_headers())
        return hive.connect(thrift_transport=transport).cursor()

    def execute(self, query):
        """Run a query and return a generator of rows as dicts keyed by column name."""
        cursor = self.get_cursor()
        cursor.execute(query)
        fields = [c[0] for c in cursor.description]
        return (dict(zip(fields, vals)) for vals in cursor)

    def execute_async(self, query):
        """Submit a query asynchronously and poll until it is no longer pending or running."""
        cursor = self.get_cursor()
        pending_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.PENDING_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        cursor.execute(query, async_=True)
        while cursor.poll().operationState in pending_states:
            print("Pending...")
        # ok, now results can be fetched the normal way...
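

# Illustrative standalone use of the client (not part of the export flow below);
# the query here is just a placeholder:
#
#     client = DatabricksClient()
#     for row in client.execute("SELECT 1 AS one"):
#         print(row)  # {"one": 1}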


def upload_page_wrapper(kwargs):
    # Pool.imap_unordered passes a single argument, so unpack the kwargs dict here.
    return upload_one_page(**kwargs)


def upload_one_page(ipage=None, npages=None):
    if None in [ipage, npages]:
        raise AssertionError("bad args")

    def query_and_upload():
        client = DatabricksClient()
        # Select only the rows whose hashed somefield falls into this page's bucket.
        query = f"""
        SELECT
            *
        FROM
            schema.table
        WHERE
            abs(hash(somefield)) % {npages} = {ipage}
        """
        with tempfile.NamedTemporaryFile(mode="w") as tfh:
            gzfh = gzip.open(tfh.name, "wt")
            rowgen = client.execute(query)
            for row in rowgen:
                writer = csv.DictWriter(gzfh, list(row))
                writer.writeheader()
                writer.writerow(row)
                break  # only write the header once!
            else:
                raise AssertionError("Expected each partition to have at least 1 record!")
            writer.writerows(rowgen)
            # close (not just flush) so the gzip trailer is written before uploading
            gzfh.close()
            # then upload the finished page to S3
            bucket = boto3.resource("s3").Bucket("superbucket")
            bucket.upload_file(
                tfh.name,
                f"users/joseph_bylund/{npages:05}_pages/page_{ipage:05}_of_{npages:05}.csv.gz",
            )

    # Retry each page a few times before giving up.
    attempts = 3
    while True:
        try:
            return query_and_upload()
        except Exception:  # pylint: disable=broad-except
            attempts -= 1
            if attempts <= 0:
                raise
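

# Illustrative: a single page can be smoke-tested without the process pool, e.g.
#     upload_one_page(ipage=0, npages=400)
# which exports just the rows where abs(hash(somefield)) % 400 == 0.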


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pages", type=int, default=400)
    return vars(parser.parse_args())


def main():
    logging.basicConfig(level=logging.INFO)
    args = get_args()
    npages = args["pages"]
    pages_done = 0
    requests = [{"ipage": ipage, "npages": npages} for ipage in range(npages)]
    with multiprocessing.Pool() as worker_pool:
        # Fan the pages out across worker processes; completion order does not matter.
        for _ in worker_pool.imap_unordered(upload_page_wrapper, requests):
            pages_done += 1
            logger.info(
                "Done with %d of %d pages (%.1f%%)",
                pages_done,
                npages,
                pages_done * 100 / npages,
            )


if __name__ == "__main__":
    main()
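
# Example invocation (illustrative; the script name is arbitrary, and the bucket, S3 prefix,
# and schema.table above are placeholders):
#     python databricks_export.py --pages 400
# splits the table into 400 hash-partitioned pages and uploads each one as a gzipped CSV to
# s3://superbucket/users/joseph_bylund/00400_pages/page_00000_of_00400.csv.gz ... page_00399_of_00400.csv.gz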