import argparse
import base64
import csv
import gzip
import logging
import multiprocessing
import tempfile
import time

import boto3
from pyhive import hive
from thrift.transport import THttpClient

logger = logging.getLogger("databricks_test")

class DatabricksClient:
    def __init__(self):
        # Placeholder credentials; fill these in before running.
        self.cluster_id = "xxx"
        self.token = "xxx"
        self.workspace_id = "xxx"
        self.host = "xxx"

    def get_full_url(self):
        # Thrift-over-HTTP endpoint for the target cluster.
        return f"https://{self.host}/sql/protocolv1/o/{self.workspace_id}/{self.cluster_id}"

    def get_auth_headers(self):
        # Databricks accepts basic auth with "token" as the username and a
        # personal access token as the password.
        auth = base64.standard_b64encode(f"token:{self.token}".encode()).decode()
        return {"Authorization": f"Basic {auth}"}

    def get_cursor(self):
        transport = THttpClient.THttpClient(self.get_full_url())
        transport.setCustomHeaders(self.get_auth_headers())
        return hive.connect(thrift_transport=transport).cursor()

    def execute(self, query):
        cursor = self.get_cursor()
        cursor.execute(query)
        fields = [c[0] for c in cursor.description]
        # Yield each row as a dict keyed by column name.
        return (dict(zip(fields, vals)) for vals in cursor)

    def execute_async(self, query):
        cursor = self.get_cursor()
        pending_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.PENDING_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        cursor.execute(query, async_=True)
        while cursor.poll().operationState in pending_states:
            logger.info("Pending...")
            time.sleep(1)  # back off instead of busy-polling the server
        # ok, now the results can be fetched the normal way
        fields = [c[0] for c in cursor.description]
        return (dict(zip(fields, vals)) for vals in cursor)
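
# A hedged usage sketch for the two query paths above ("SELECT 1" stands in
# for a real query); both methods yield dicts keyed by column name:
#
#     client = DatabricksClient()
#     for row in client.execute("SELECT 1"):        # blocking path
#         print(row)
#     for row in client.execute_async("SELECT 1"):  # polls until finished
#         print(row)
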
def upload_page_wrapper(kwargs):
    # Pool.imap_unordered passes a single argument, so unpack the kwargs dict.
    return upload_one_page(**kwargs)


def upload_one_page(ipage=None, npages=None):
    if None in (ipage, npages):
        raise AssertionError("bad args")
    def query_and_upload():
        client = DatabricksClient()
        # Partition the table by hashing somefield, so each page pulls a
        # disjoint slice of the rows.
        query = f"""
        SELECT
            *
        FROM
            schema.table
        WHERE
            abs(hash(somefield)) % {npages} = {ipage}
        """
        with tempfile.NamedTemporaryFile(mode="w") as tfh:
            # Close the gzip handle (via the with-block) before uploading so
            # the gzip trailer gets written; flush() alone leaves it truncated.
            with gzip.open(tfh.name, "wt") as gzfh:
                rowgen = client.execute(query)
                for row in rowgen:
                    writer = csv.DictWriter(gzfh, list(row))
                    writer.writeheader()
                    writer.writerow(row)
                    break  # only write the header once!
                else:
                    raise AssertionError("Expected each partition to have at least 1 record!")
                writer.writerows(rowgen)
            # then upload the finished page to S3
            bucket = boto3.resource("s3").Bucket("superbucket")
            bucket.upload_file(
                tfh.name,
                f"users/joseph_bylund/{npages:05}_pages/page_{ipage:05}_of_{npages:05}.csv.gz",
            )

    attempts = 3
    while True:
        try:
            return query_and_upload()
        except Exception:  # pylint: disable=broad-except
            attempts -= 1
            if attempts <= 0:
                raise
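
# A quick, hedged way to sanity-check that a page landed in S3; the key below
# mirrors the upload path above for --pages 400, page 0:
#
#     obj = boto3.resource("s3").Object(
#         "superbucket",
#         "users/joseph_bylund/00400_pages/page_00000_of_00400.csv.gz",
#     )
#     print(obj.content_length)
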
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pages", type=int, default=400)
    return vars(parser.parse_args())

def main():
    logging.basicConfig(level=logging.INFO)
    args = get_args()
    npages = args["pages"]
    pages_done = 0
    requests = [{"ipage": ipage, "npages": npages} for ipage in range(npages)]
    with multiprocessing.Pool() as worker_pool:
        for _ in worker_pool.imap_unordered(upload_page_wrapper, requests):
            pages_done += 1
            logger.info(
                "Done with %d of %d pages (%.1f%%)",
                pages_done,
                npages,
                pages_done * 100 / npages,
            )


if "__main__" == __name__:
    main()
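
# Example invocation, assuming AWS credentials are configured in the
# environment and the DatabricksClient placeholders are filled in (the
# filename here is assumed from the logger name):
#
#     python databricks_test.py --pages 400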