Skip to content

Instantly share code, notes, and snippets.

@JanCBrammer
Created October 16, 2023 12:29
Show Gist options
  • Save JanCBrammer/662f47f35ab153868d846f98655e53ea to your computer and use it in GitHub Desktop.
Save JanCBrammer/662f47f35ab153868d846f98655e53ea to your computer and use it in GitHub Desktop.
Download PubChem
"""Download gzipped SDF files from ftp.ncbi.nlm.nih.gov/pubchem/.
https://openbook.rheinwerk-verlag.de/python/34_003.html
https://docs.python.org/3/library/ftplib.html
https://pubchem.ncbi.nlm.nih.gov/docs/downloads#section=From-the-PubChem-FTP-Site
"""
from ftplib import FTP
from ftplib import all_errors as FTPException
import hashlib
from pathlib import Path
from typing import Iterator
from contextlib import contextmanager
@contextmanager
def pubchem_ftp_client(dataset_directory: str):
    """Yield an anonymous FTP client positioned at `dataset_directory`
    on ftp.ncbi.nlm.nih.gov.

    FTP errors raised inside the `with` body are printed and suppressed;
    the connection is always closed on exit.
    """
    ftp = FTP("ftp.ncbi.nlm.nih.gov")
    ftp.login()  # anonymous login, no credentials needed
    ftp.cwd(dataset_directory)
    try:
        yield ftp
    except FTPException as error:
        print(error)
    finally:
        ftp.close()
class LineData:
    """Accumulate the first whitespace-delimited token of every line fed in.

    Intended as a line callback for `ftplib.FTP.retrlines`, e.g. to collect
    the hex digest from an `.md5` file (whose first token is the digest).
    """

    def __init__(self):
        self.content = ""

    def __call__(self, line):
        tokens = line.split()
        self.content += tokens[0].strip()
class MD5:
    """Incrementally compute an MD5 digest over a stream of byte blocks.

    Feed blocks by calling the instance; read the hex digest via `hash`.
    Intended as a block callback for `ftplib.FTP.retrbinary`.
    """

    def __init__(self):
        self.hash_function = hashlib.md5()

    def __call__(self, block: bytes) -> None:
        self.hash_function.update(block)

    @property
    def hash(self) -> str:
        """Hexadecimal digest of all blocks fed in so far."""
        return self.hash_function.hexdigest()
def _fetch_gzipped_sdf_filenames(dataset_directory: str) -> list[str]:
    """Fetch names of all gzipped SDF files in `dataset_directory` on the FTP server.

    Uses an MLSD directory listing and keeps only names ending in ".sdf.gz".
    """
    with pubchem_ftp_client(dataset_directory) as client:
        # `mlsd` yields (name, facts) pairs; iterate it directly instead of
        # materializing the whole listing into a throwaway list first.
        return [
            name
            for name, _facts in client.mlsd()
            if name.endswith(".sdf.gz")
        ]
def _fetch_gzipped_sdf(
    filename: str,
    destination_directory: str,
    dataset_directory: str,
    overwrite_file: bool,
) -> str:
    """Fetch gzipped SDF from FTP server.

    Streams the file to `destination_directory` while computing its MD5
    locally, then validates it against the hash published on the server
    (when one exists).

    Returns the local POSIX path on success, or "" when the download was
    skipped (file already present) or failed hash validation.
    """
    filepath = Path(destination_directory).joinpath(filename)
    if filepath.exists() and not overwrite_file:
        print(f"{filepath.as_posix()} already exists. Skipping download.")
        return ""
    md5_local = MD5()

    def distribute_ftp_callback(block: bytes):
        # Hash and persist each block as it arrives from the server.
        md5_local(block)
        gzipped_sdf.write(block)

    with filepath.open("wb") as gzipped_sdf, pubchem_ftp_client(
        dataset_directory
    ) as client:
        # Fix: the RETR command must name the file being requested.
        client.retrbinary(f"RETR {filename}", distribute_ftp_callback)
    md5_server = _fetch_gzipped_sdf_hash(filename, dataset_directory)
    if md5_server:
        # Some PubChem datasets (e.g., Compound 3D) don't have MD5 hashes.
        if md5_local.hash != md5_server:
            print(
                f"The hash of {filepath.as_posix()} doesn't match its corresponding hash on the FTP server. Removing the file locally."
            )
            filepath.unlink()
            return ""
    return filepath.as_posix()
def _fetch_gzipped_sdf_hash(filename: str, dataset_directory: str) -> str:
    """Fetch the MD5 hash of `filename` from the FTP server.

    Returns "" when no `.md5` companion file exists for the dataset.
    """
    md5 = LineData()
    with pubchem_ftp_client(dataset_directory) as client:
        try:
            # Fix: the RETR command must name the hash file being requested.
            # The `.md5` file's first token is the hex digest itself.
            client.retrlines(f"RETR {filename}.md5", md5)
        except FTPException:
            # Some PubChem datasets (e.g., Compound 3D) don't have MD5 hashes.
            pass
    return md5.content
def download_all_sdf(
    destination_directory: str,
    dataset_directory: str,
    overwrite_files: bool = False,
) -> Iterator[str]:
    """Generator yielding file paths of successfully downloaded gzipped SDF.

    `dataset_directory` can be one of
    `pubchem/Compound/CURRENT-Full/SDF/`,
    `pubchem/Substance/CURRENT-Full/SDF/`, or
    `pubchem/Compound_3D/01_conf_per_cmpd/SDF`.
    """
    for filename in _fetch_gzipped_sdf_filenames(dataset_directory):
        filepath = _fetch_gzipped_sdf(
            filename, destination_directory, dataset_directory, overwrite_files
        )
        # An empty path signals a skipped or failed download.
        if filepath:
            yield filepath
def get_id(molfile: str) -> str:
    """Return the first whitespace-delimited token of `molfile` (the record ID)."""
    first_token = molfile.split()[0]
    return first_token.strip()
import pytest
from sdf_pipeline import pubchem
@pytest.mark.parametrize(
    "dataset_directory, expected_sdf_paths",
    [
        (
            "pubchem/Compound/CURRENT-Full/SDF/",
            [
                "Compound_000000001_000500000.sdf.gz",
                "Compound_001000001_001500000.sdf.gz",
            ],
        ),
        (
            "pubchem/Substance/CURRENT-Full/SDF/",
            [
                "Substance_000500001_001000000.sdf.gz",
                "Substance_000000001_000500000.sdf.gz",
            ],
        ),
        (
            "pubchem/Compound_3D/01_conf_per_cmpd/SDF",
            [
                "00000001_00025000.sdf.gz",
                "00025001_00050000.sdf.gz",
            ],
        ),
    ],
)
def test_download_all_sdf(tmp_path, dataset_directory, expected_sdf_paths):
    """Download the first two archives of each dataset and verify their paths."""
    sdf_path_generator = pubchem.download_all_sdf(
        destination_directory=str(tmp_path), dataset_directory=dataset_directory
    )
    downloaded_paths = set()
    for _ in range(2):
        downloaded_paths.add(next(sdf_path_generator))
    expected_paths = {str(tmp_path / name) for name in expected_sdf_paths}
    assert downloaded_paths == expected_paths
    assert all((tmp_path / path).exists() for path in downloaded_paths)
@pytest.mark.parametrize(
    "dataset_directory, expected_n_sdf_paths",
    [
        ("pubchem/Compound/CURRENT-Full/SDF/", 338),
        ("pubchem/Substance/CURRENT-Full/SDF/", 894),
        ("pubchem/Compound_3D/01_conf_per_cmpd/SDF", 6646),
    ],
)
def test_fetch_gzipped_sdf_filenames(dataset_directory, expected_n_sdf_paths):
    """Each dataset listing should contain the known number of .sdf.gz files."""
    filenames = pubchem._fetch_gzipped_sdf_filenames(dataset_directory)
    assert len(filenames) == expected_n_sdf_paths
@pytest.mark.parametrize(
    "dataset_directory, filename, expected_hash",
    [
        (
            "pubchem/Compound/CURRENT-Full/SDF/",
            "Compound_000000001_000500000.sdf.gz",
            "81d318fd569898ffc1506478d6f3389b",
        ),
        (
            "pubchem/Substance/CURRENT-Full/SDF/",
            "Substance_000500001_001000000.sdf.gz",
            "5365255fe6acbe94a9d48c7ca1a745b9",
        ),
        ("pubchem/Compound_3D/01_conf_per_cmpd/SDF", "00000001_00025000.sdf.gz", ""),
    ],
)
def test_fetch_gzipped_sdf_hash(dataset_directory, filename, expected_hash):
    """Server hashes match known values; datasets without hashes yield ""."""
    fetched_hash = pubchem._fetch_gzipped_sdf_hash(filename, dataset_directory)
    assert fetched_hash == expected_hash
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment