@pbrumblay
Created November 5, 2018 22:16
Airflow custom Google Cloud Storage Hook with resumable uploads, partial downloads, and compose (everyone else calls it "concatenating") functionality
from google.cloud import storage
from airflow.hooks.base_hook import BaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
import random
import string


class GCSCustomHook(BaseHook, LoggingMixin):
    def __init__(self, storage_conn_id='google_cloud_storage_default'):
        self.storage_conn_id = storage_conn_id
        self.conn = None

    def get_conn(self):
        """
        Returns a Google Cloud Storage client.
        """
        if self.conn is None:
            params = self.get_connection(self.storage_conn_id)
            project = params.extra_dejson.get('project')
            self.log.info('Getting connection using project %s', project)
            self.conn = storage.Client(project)
        return self.conn

    def list(self, bucket_name, prefix):
        conn = self.get_conn()
        bucket = conn.lookup_bucket(bucket_name)
        if bucket is None:
            raise ValueError('Could not find bucket %s' % bucket_name)
        return bucket.list_blobs(prefix=prefix)
    def compose(self, bucket_name, prefix, new_blob_name):
        """Recursively combine (aka "compose") blob shards in groups of 32."""
        blobs_iterator = self.list(bucket_name, prefix)
        source_blobs = []
        # list() returns an iterator. Get all the entries so we can count them.
        for b in blobs_iterator:
            source_blobs.append(b)
        conn = self.get_conn()
        bucket = conn.lookup_bucket(bucket_name)
        if bucket is None:
            raise ValueError('Could not find bucket %s' % bucket_name)
        if not source_blobs:
            # Guard against an empty prefix; recursing here would never terminate.
            raise ValueError('No blobs found under prefix %s' % prefix)
        blob_content_type = None
        # recursive base case: if there is only one blob with the given prefix,
        # rename it to the desired name
        if len(source_blobs) == 1:
            self.log.info("Found 1 blob matching prefix, renaming to: %s", new_blob_name)
            bucket.rename_blob(source_blobs[0], new_blob_name)
        else:
            # create a new prefix to compose blobs into, in groups of 32
            random_name = ''.join(random.choice(string.ascii_lowercase) for _ in range(10))
            # group the blobs
            i = 0
            group = -1
            list_of_lists = []
            for s in source_blobs:
                if blob_content_type is None:
                    blob_content_type = s.content_type
                if i % 32 == 0:
                    list_of_lists.append([])
                    group = group + 1
                self.log.info("Adding blob to group [%s]: %s", group, s.path)
                list_of_lists[group].append(s)
                i = i + 1
            # for each group, compose under the new prefix
            k = 0
            for l in list_of_lists:
                new_blob = bucket.blob(random_name + "-" + str(k))
                self.log.info("Creating blob: %s", new_blob.path)
                # workaround: https://github.com/googleapis/google-cloud-python/issues/5834
                new_blob.content_type = blob_content_type
                new_blob.compose(l)
                k = k + 1
            # delete all blobs with the original prefix to clean up
            for s in source_blobs:
                self.log.info("Deleting blob: %s", s.path)
                s.delete()
            # repeat the process using the new prefix
            self.compose(bucket_name, random_name, new_blob_name)
"""
Use google cloud storage api to implement a resumable upload which does not require
the entire file to be written to disk before transmission - the api supports "file like objects"
which, when a chunk size is set can steam data from the source object into GCS.
Only supports using default security context. Cannot use / inherit from airflow GCS hooks since
they use the wrong (deprecated) oauth2 lib.
"""
def resumable_upload(self, file_object, bucket_name, blob_name):
"""
Returns a list of files on the remote system.
:param file_object: a file like object
:type file_object: io.IOBase
:param bucket_name: a GCS bucket
:type bucket_name: str
:param blob_name: the destination path (blob)
:type blob_name: str
"""
conn = self.get_conn()
bucket = conn.lookup_bucket(bucket_name)
if bucket is None:
raise ValueError('Could not find bucket ' % bucket_name)
self.log.info("Found bucket starting upload to gs://%s/%s", bucket_name, blob_name)
blob = bucket.blob(blob_name)
blob.chunk_size = 1024 * 1024 * 3 # 3mb
blob.upload_from_file(file_object)
    def download_file_part(self, bucket_name, blob_name, start, end, file_name):
        conn = self.get_conn()
        bucket = conn.lookup_bucket(bucket_name)
        if bucket is None:
            raise ValueError('Could not find bucket %s' % bucket_name)
        self.log.info("Found bucket %s. Downloading bytes [%s-%s] of blob %s to %s",
                      bucket_name, start, end, blob_name, file_name)
        blob = bucket.blob(blob_name)
        blob.download_to_filename(file_name, start=start, end=end)
    def download_file_string(self, bucket_name, blob_name):
        conn = self.get_conn()
        bucket = conn.lookup_bucket(bucket_name)
        if bucket is None:
            raise ValueError('Could not find bucket %s' % bucket_name)
        self.log.info("Found bucket %s. Downloading blob %s as a string", bucket_name, blob_name)
        blob = bucket.blob(blob_name)
        return blob.download_as_string()
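
A minimal usage sketch of the hook from a PythonOperator, assuming Airflow 1.x, a google_cloud_storage_default connection whose extras set "project", and hypothetical bucket, prefix, and module names:

import io
from datetime import datetime

from airflow import DAG
from airflow.operators.python_operator import PythonOperator

# Hypothetical import path; adjust to wherever this hook module lives.
from gcs_custom_hook import GCSCustomHook


def merge_shards(**_):
    hook = GCSCustomHook(storage_conn_id='google_cloud_storage_default')

    # Stream an in-memory buffer straight to GCS; nothing is written to local disk first.
    hook.resumable_upload(io.BytesIO(b'example payload'), 'my-example-bucket', 'staging/part-00099.csv')

    # Combine every blob under the prefix into a single object, 32 at a time.
    hook.compose('my-example-bucket', 'staging/part-', 'merged/output.csv')

    # Pull only the first kilobyte of the result as a quick sanity check.
    hook.download_file_part('my-example-bucket', 'merged/output.csv', 0, 1023, '/tmp/output.head')


with DAG('gcs_custom_hook_example', start_date=datetime(2018, 11, 1), schedule_interval=None) as dag:
    PythonOperator(task_id='merge_shards', python_callable=merge_shards, provide_context=True)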