Created
May 29, 2020 21:29
-
-
Save bgmello/a0e31b37a527d26ceea537825aa85388 to your computer and use it in GitHub Desktop.
Class to download data from multiple URLs concurrently
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import asyncio | |
import requests | |
from concurrent.futures import ThreadPoolExecutor | |
class DataLoader:
    def __init__(self, urls, fnames, data_dir, workers=200, verbose=True, timeout=None):
        '''
        Class to download data from multiple urls concurrently.

        Each ``urls[i]`` is fetched with an HTTP GET and its body is written to
        ``fnames[i]``. Urls and file names are paired with ``zip``, so if the
        two lists differ in length the extra entries are silently ignored.

        params:
            urls(list): urls where the data is stored
            fnames(list): files where the data will be stored; used exactly as
                given (they are NOT joined with ``data_dir``)
            data_dir(str): Data directory, created (with parents) if missing
            workers(int): Number of worker threads used for the downloads
            verbose(bool): Print the progress
            timeout(float | tuple | None): forwarded to ``session.get`` as the
                request timeout; ``None`` (default) waits indefinitely
        '''
        self.data_dir = data_dir
        self.verbose = verbose
        self.workers = workers
        self.timeout = timeout
        self.create_dir(self.data_dir)
        self.urls = urls
        self.fnames = fnames

    def run(self):
        '''
        Download every url and block until all downloads have finished.

        returns:
            tuple(int, int): (number of good responses, number of bad responses).
            Files that already existed are skipped and counted in neither.
        '''
        # asyncio.run creates and closes a fresh event loop. Calling
        # asyncio.get_event_loop() with no running loop is deprecated and
        # raises on recent Python versions.
        return asyncio.run(self.get_data_async())

    def create_dir(self, directory):
        '''Create ``directory`` and any missing parents if it does not exist.'''
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(directory, exist_ok=True)

    def fetch_result(self, session, url, fname):
        '''
        Download a single url into ``fname``.

        params:
            session: session used to issue the GET request
                (a ``requests.Session`` when called from ``get_data_async``)
            url(str): url to download
            fname(str): path of the file the response body is written to
        returns:
            int: -1 if ``fname`` already exists (nothing is downloaded),
                 0 on a non-200 response or a request error,
                 1 when the body was written to ``fname``
        '''
        if os.path.isfile(fname):  # already downloaded, skip it
            return -1
        try:
            with session.get(url, timeout=self.timeout) as response:
                if response.status_code != 200:
                    if self.verbose:
                        print(f"Bad response from url: {url}")
                    return 0
                body = response.content
        except requests.RequestException as err:
            # One unreachable url must not abort the whole batch; count it as bad.
            if self.verbose:
                print(f"Request failed for url: {url} ({err})")
            return 0
        # Write the raw bytes. Decoding to str and re-encoding with the platform
        # default encoding could fail or silently change the downloaded data.
        with open(fname, 'wb') as f:
            f.write(body)
        if self.verbose:
            print(f"Good response from url: {url}")
        return 1

    async def get_data_async(self):
        '''
        Download all urls concurrently on a thread pool of ``self.workers`` threads.

        returns:
            tuple(int, int): (number of good responses, number of bad responses).
            Skipped files (already present on disk) are counted in neither.
        '''
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=self.workers) as executor, \
                requests.Session() as session:
            # requests pools at most 10 connections per host by default. Size the
            # pool to the worker count so concurrent threads reuse connections
            # instead of discarding them.
            # NOTE(review): requests does not document Session as thread-safe;
            # sharing one across worker threads is common but not guaranteed.
            adapter = requests.adapters.HTTPAdapter(pool_maxsize=self.workers)
            session.mount('http://', adapter)
            session.mount('https://', adapter)
            tasks = [
                loop.run_in_executor(executor, self.fetch_result, session, url, fname)
                for url, fname in zip(self.urls, self.fnames)
            ]
            results = await asyncio.gather(*tasks)
        # fetch_result returns 1 (good), 0 (bad) or -1 (skipped).
        num_good_resp = sum(1 for r in results if r == 1)
        num_bad_resp = sum(1 for r in results if r == 0)
        if self.verbose:
            print(f"Number of bad responses: {num_bad_resp}")
        return num_good_resp, num_bad_resp
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.