Skip to content

Instantly share code, notes, and snippets.

@kijes
Created November 27, 2017 08:43
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save kijes/163bbceca3b6fb25f210f1d2033a4f97 to your computer and use it in GitHub Desktop.
Preparing data for training on FloydHub by downloading from Google Drive
import logging
import os
import requests
from common import utils
# Number of bytes requested per iteration when streaming a download to disk
# (passed to response.iter_content in save_response_content); 32768 = 32 KiB.
CHUNK_SIZE = 32768
def download_file_from_google_drive(id, destination):
    """Download the Google Drive file with ID *id* and write it to *destination*.

    Args:
        id: Google Drive file ID. (The name shadows the builtin but is kept
            so keyword callers keep working.)
        destination: Local file path the downloaded bytes are written to.

    Raises:
        requests.HTTPError: If Google Drive answers with an HTTP error status,
            instead of silently saving the error page as the file.
    """
    URL = "https://docs.google.com/uc?export=download"
    # The context manager closes the session (and its pooled connections)
    # even if the download fails part-way.
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        token = get_confirm_token(response)
        if token:
            # Drive set a `download_warning*` cookie (presumably its "cannot
            # scan this large file" interstitial); echo the value back as
            # `confirm` to get the real content. Release the first streamed
            # response before issuing the second request.
            response.close()
            params = {'id': id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)
        # Fail loudly on 4xx/5xx rather than writing an HTML error page to
        # `destination`, which would only surface later as an unpack error.
        response.raise_for_status()
        save_response_content(response, destination)
def get_confirm_token(response):
    """Return the value of the first cookie whose name starts with
    ``download_warning``, or ``None`` if *response* carries no such cookie.
    """
    warning_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(warning_values, None)
def save_response_content(response, destination):
    """Stream the body of *response* into the file at *destination*.

    The file is opened in binary write mode (overwriting any existing file),
    the body is read in CHUNK_SIZE-byte pieces, and empty chunks (keep-alive
    new chunks) are skipped.
    """
    with open(destination, "wb") as out_file:
        # filter(None, ...) drops falsy (empty) chunks before writing.
        out_file.writelines(filter(None, response.iter_content(CHUNK_SIZE)))
def download_files_from_google_drive(file_id, output_path):
    """Download Google Drive file *file_id* and save it at *output_path*.

    The parent directory of *output_path* (resolved to an absolute path) is
    passed to ``utils.makedirs`` first — presumably creating it if missing —
    so the download has somewhere to land.
    """
    logging.info('Downloading data from Google Drive')
    output_path_dir = os.path.dirname(os.path.abspath(output_path))
    utils.makedirs(output_path_dir)
    # Arguments are passed separately so string formatting is deferred to
    # the logging framework and skipped when DEBUG output is disabled.
    logging.debug('Downloading data [%s] to [%s]', file_id, output_path)
    download_file_from_google_drive(file_id, output_path)
import logging
import sys
from common.data import download_from_drive
from common.data import unpack
def main(args):
    """Download the Redux training archives from Google Drive and unpack them.

    Every archive is first downloaded to ``<prefix>/output/<name>``; then each
    one is extracted into ``<prefix>/output`` and removed afterwards (the
    third ``unpack_data`` argument is True).

    Args:
        args: Command-line arguments without the program name; currently
            unused.
    """
    prefix = ""
    FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=FORMAT, level=logging.DEBUG)
    logging.info('Preparing Redux data')

    output_dir = prefix + "/output"
    # (Google Drive file ID, archive file name) for every dataset to fetch.
    downloads = [
        ("0B7lvbTNg1cHtOWtyQU9LbU1yaE0", "train_data.tar.gz"),
        ("0B7lvbTNg1cHtWUdXcG9US0dhclk", "tf_vgg16.tar.gz"),
    ]
    archive_paths = [output_dir + "/" + name for _, name in downloads]

    # Download every archive before unpacking any of them.
    for (file_id, _), archive_path in zip(downloads, archive_paths):
        download_from_drive.download_files_from_google_drive(file_id, archive_path)
    for archive_path in archive_paths:
        # True -> delete the archive once it has been extracted.
        unpack.unpack_data(archive_path, output_dir, True)

    logging.info('Prepared data')
if __name__ == '__main__':
    # Script entry point: hand main() the CLI arguments, minus argv[0].
    main(args=sys.argv[1:])
#!/bin/bash
# Run the full Redux training-data preparation script
# (py_redux/prepare/prepare_train_full_fh.py) with this script's directory on
# PYTHONPATH — presumably so the project's `common` package can be imported.
set -e

# Quote the command substitution so paths containing spaces work, and let
# `set -e` abort if the cd fails instead of using the wrong directory.
SCRIPT_DIR=$(cd -- "$(dirname -- "$0")" && pwd -P)
export PYTHONPATH="${SCRIPT_DIR}"
# Resolve the Python script relative to this script, not the caller's cwd.
python "${SCRIPT_DIR}/py_redux/prepare/prepare_train_full_fh.py"
#!/bin/bash
# Submit `bash prepare_train_full_fh.sh` as a FloydHub job in the
# tensorflow-1.2 environment.
set -e

# Quote the command substitution so paths containing spaces work, and let
# `set -e` abort if the cd fails.
SCRIPT_DIR=$(cd -- "$(dirname -- "$0")" && pwd -P)
export PYTHONPATH="${SCRIPT_DIR}"
floyd run --env tensorflow-1.2 "bash prepare_train_full_fh.sh"
import logging
import os
import tarfile
import zipfile
from common import utils
def _extract_tar(archive_path, output_dir):
    """Extract a tar archive into output_dir, refusing members that escape it."""
    # `with` closes the archive even if extraction raises.
    with tarfile.open(archive_path) as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters (PEP 706): "data" rejects absolute paths,
            # ".." traversal, links pointing outside output_dir and special
            # files (CVE-2007-4559).
            tar.extractall(output_dir, filter="data")
        else:
            # Older Pythons: at least reject member paths that resolve
            # outside output_dir. (This does not vet link targets.)
            root = os.path.realpath(output_dir)
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(root, member.name))
                if os.path.commonpath([root, target]) != root:
                    raise ValueError(
                        'Refusing to extract [%s] outside [%s]'
                        % (member.name, output_dir))
            tar.extractall(output_dir)


def unpack_data(archive_path, output_dir, remove_archive):
    """Extract a zip or tar archive into *output_dir*, optionally deleting it.

    The archive type is chosen from the end of *archive_path*: ``zip`` for
    zip archives; ``tar``, ``tar.gz``/``tgz``, ``tar.bz2``/``tbz2`` or
    ``tar.xz``/``txz`` for tar archives (``tarfile.open`` auto-detects the
    compression).

    Args:
        archive_path: Path to the archive file.
        output_dir: Destination directory; passed to ``utils.makedirs``
            first (presumably creating it if missing).
        remove_archive: If true, delete the archive after extraction.

    Raises:
        ValueError: If the file name is not a supported archive type, or a tar
            member would be extracted outside *output_dir*. ValueError
            subclasses Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    utils.makedirs(output_dir)
    logging.info('Unpacking [%s] to [%s]', archive_path, output_dir)
    if archive_path.endswith("zip"):
        # ZipFile.extractall already strips absolute paths and "..".
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)
    elif archive_path.endswith(
            ("tar.gz", "tar", "tgz", "tar.bz2", "tbz2", "tar.xz", "txz")):
        _extract_tar(archive_path, output_dir)
    else:
        raise ValueError('Unknown archive type')
    if remove_archive:
        logging.info('Removing archive [%s]', archive_path)
        os.remove(archive_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment