Created
May 19, 2024 09:20
-
-
Save baberabb/35c11b7020fe22b1156f3338550549ae to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import os | |
import random | |
from collections import deque | |
import ssl | |
from urllib.parse import urlparse | |
import aiohttp | |
import polars as pl | |
import aiofiles | |
from aiohttp import TCPConnector | |
# from limiter import Limiter | |
import certifi | |
from tenacity import retry, wait_exponential, wait_random_exponential | |
from tqdm.asyncio import tqdm_asyncio | |
# limit_downloads = Limiter(rate=20, capacity=1000, consume=1) | |
BATCH_SIZE = 1000  # rows buffered before a download batch is launched
BASE_DOWNLOAD_PATH = "/openalex_downloads"  # directory PDFs are written into
TIME_OUT = 60  # total per-request timeout for aiohttp, in seconds
# Rotated desktop User-Agent strings to make requests look like a browser.
USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Firefox/89.0",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:89.0) Gecko/20100101 Firefox/89.0",
]


def HEADER(x: str) -> dict[str, str]:
    """Build request headers with a random User-Agent and *x* as Referer.

    A plain ``def`` instead of the original ``HEADER = lambda x: ...``
    (PEP 8 E731); callers are unchanged.
    """
    return {
        "User-Agent": random.choice(USER_AGENTS),
        "Referer": x,
    }
async def get_request(session, queries: list, uuid: str) -> tuple[str, bytes | None]:
    """Try each candidate URL in *queries* until one downloads successfully.

    Args:
        session: Open ``aiohttp.ClientSession`` used for the GET requests
            (TLS is configured on the session's connector by the caller).
        queries: Candidate URLs for the same document, tried in order.
        uuid: Identifier (possibly URL-shaped) used to derive the output
            filename.

    Returns:
        ``(filename, content)`` for the first URL that responds, or
        ``("error", None)`` when every URL fails.
    """
    # Derive a filesystem-safe name ONCE, outside the retry loop.  The
    # original re-sanitized (and clobbered) the ``uuid`` parameter and
    # recomputed the collision-free filename on every iteration, and also
    # built an ssl_context it never used (dead code, removed).
    safe_uuid = urlparse(uuid).path.replace("/", "")
    filename = os.path.join(BASE_DOWNLOAD_PATH, f"{safe_uuid}.pdf")
    file_number = 1
    # Avoid clobbering a file from an earlier run: append a numeric suffix.
    while os.path.exists(filename):
        filename = os.path.join(
            BASE_DOWNLOAD_PATH, f"{safe_uuid}_{file_number}.pdf"
        )
        file_number += 1
    for query in queries:
        try:
            async with session.get(url=query, headers=HEADER(query)) as response:
                content = await response.read()
                # Writing is deferred to the caller; just return the payload.
                return filename, content
        except Exception as e:  # best-effort: fall through to the next URL
            print(f"An error occurred with query {query}: {e}")
            continue  # Try the next query in the list
    # If all queries fail, report and signal failure to the caller.
    print(f"All queries failed: {queries}")
    return "error", None
async def get_batched(session, batch, seen: set | None = None):
    """Download every ``(queries, uuid)`` pair in *batch* concurrently.

    Args:
        session: Open ``aiohttp.ClientSession`` shared by all requests.
        batch: Iterable of ``(queries, uuid)`` tuples; falsy entries are
            skipped.
        seen: Accepted for backward compatibility but unused — the original
            built a set from it and never read it (dead logic, removed).
            Annotation fixed: the default ``None`` is not a ``set``.

    Returns:
        List of ``(filename, content)`` tuples from :func:`get_request`,
        gathered with a tqdm progress bar.
    """
    tasks = [
        asyncio.ensure_future(get_request(session, queries=q[0], uuid=q[1]))
        for q in batch
        if q
    ]
    return await tqdm_asyncio.gather(
        *tasks, desc="Collecting batch", leave=True, position=0
    )
async def _flush_batch(batch, ssl_context, timeout, output) -> None:
    """Download *batch* with a fresh TLS session and persist successful files.

    Appends the path of every file actually written to *output*.  Extracted
    because the original ``main`` duplicated this download-then-save logic
    verbatim for full batches and for the final partial batch.
    """
    async with aiohttp.ClientSession(
        connector=TCPConnector(ssl=ssl_context, limit=50),
        timeout=timeout,
    ) as session:
        responses = await get_batched(session, batch)
    for filename, content in responses:
        if content:
            with open(filename, "wb") as f:
                f.write(content)
            output.append(filename)


async def main(file_loc):
    """Stream rows from a parquet manifest and download each row's PDFs.

    Args:
        file_loc: Path to a parquet file with ``identifier`` and comma-joined
            ``pdf_url`` columns.

    Returns:
        List of filenames that were successfully downloaded and saved.
    """
    # or df = datasets.load_dataset(url, split="train").to_polars().lazy()
    df = pl.scan_parquet(file_loc)
    rows = (
        df.with_columns(pl.col("pdf_url").str.split(","))
        .select(["identifier", "pdf_url"])
        .collect(streaming=True)
        .iter_rows(named=True)
    )
    batches = deque()
    output = []
    # Strict TLS verification; applied to each batch's connector.
    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = True
    ssl_context.verify_mode = ssl.CERT_REQUIRED
    timeout = aiohttp.ClientTimeout(total=TIME_OUT)
    for row in rows:
        batches.append((row["pdf_url"], row["identifier"]))
        if len(batches) == BATCH_SIZE:
            await _flush_batch(batches, ssl_context, timeout, output)
            batches.clear()
    # Handle the final partial batch, if any rows remain.
    if batches:
        print("Saving Batch")
        await _flush_batch(batches, ssl_context, timeout, output)
        print("Batch Saved")
    return output
if __name__ == "__main__":
    # Parquet manifest to process; edit as needed.
    FILE_LOCATION = "/Openalex Extraction 2.parquet"
    saved_files = asyncio.run(main(FILE_LOCATION))
    print(saved_files)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment