Skip to content

Instantly share code, notes, and snippets.

@mojaveazure
Created April 30, 2020 07:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mojaveazure/922aa904b9ac212e627f522c2b816a52 to your computer and use it in GitHub Desktop.
Save mojaveazure/922aa904b9ac212e627f522c2b816a52 to your computer and use it in GitHub Desktop.
Script to generate an H5AD file following Scanpy's PBMC 3k tutorial
#!/usr/bin/env python3
"""Generate an H5AD file from the PBMC3k dataset"""
import os
import sys
import time
import shutil
import urllib
import logging
import tarfile
import argparse
import warnings
# Check version information
# Only Python 3 with a minor version of at least 6 is accepted; any other
# major version is rejected as well.
if sys.version_info.major != 3 or sys.version_info.minor < 6:
    sys.exit("This script requires Python 3.6 or greater")
# Get third-party modules
# Warnings emitted while these packages are imported are ignored; the
# previous warning filters are restored when the ``with`` block ends.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    try:
        import numpy
        import pandas
        import scanpy
    except ImportError as err:  # type: ImportError
        if err.name:
            sys.exit("Please install %s" % err.name)
        # ``err.name`` is None when a package raises a bare ImportError
        # (e.g. a broken installation); show the underlying message instead
        # of the misleading "Please install None".
        sys.exit("Failed to import a required module: %s" % err)
# Global constants
# Tarball that the script downloads
URL = 'http://cf.10xgenomics.com/samples/cell-exp/1.1.0/pbmc3k/pbmc3k_filtered_gene_bc_matrices.tar.gz'  # type: str
# Directory, relative to the extraction directory, that is handed to scanpy
STRUCT_10X = 'filtered_gene_bc_matrices/hg19'  # type: str
# Layout of log records and of their timestamps
LOG_FORMAT = '%(asctime)s %(levelname)s:\t%(message)s'  # type: str
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'  # type: str
# Maps each accepted --verbosity value to the logging level constant of the
# same (upper-cased) name
LOG_LEVELS = {  # type: Dict[str, int]
    _level: getattr(logging, _level.upper())
    for _level in ('debug', 'info', 'warning', 'error', 'critical')
}
def _runtype(value):  # type: (str) -> str
    """Validate the positional ``run`` argument

    Used as an argparse ``type`` callback: returns ``value`` unchanged when
    it is exactly ``'run'`` and raises ``argparse.ArgumentTypeError`` for
    anything else.
    """
    if value == 'run':
        return value
    raise argparse.ArgumentTypeError("Please pass 'run' instead")
def download(url, outname=None, outdir=os.getcwd()):  # type: (str, Optional[str], str) -> str
    """Download a file

    Arguments:
        url: location of the file to fetch
        outname: name to save the file under; when empty or None, the last
            path component of ``url`` is used
        outdir: directory to save the file in, created if it does not exist;
            the default is the working directory at the time this function
            was defined, not at the time it is called

    Returns:
        The path of the downloaded file (``outdir`` joined with the file name)
    """
    # ``import urllib`` on its own does not load the ``request`` submodule,
    # so import it explicitly instead of relying on another library having
    # done so already
    import urllib.request
    os.makedirs(outdir, exist_ok=True)
    if not outname:
        outname = os.path.basename(url)  # type: str
    outname = os.path.join(outdir, outname)  # type: str
    logging.info("Downloading file to %s", outname)
    dl_start = time.time()  # type: float
    urllib.request.urlretrieve(url, outname)
    logging.info("File downloaded to %s", outname)
    # Three decimal places, matching the other timing messages in this script
    logging.debug("Downloading file took %s seconds", round(time.time() - dl_start, 3))
    return outname
if __name__ == "__main__":
    # Script entry point: download the PBMC3k tarball, load it with scanpy,
    # and write either the unprocessed ('raw') or the fully processed
    # ('final') AnnData object to <outdir>/pbmc3k_<mode>.h5ad
    parser = argparse.ArgumentParser()  # type: argparse.ArgumentParser
    parser.add_argument(  # Run the thing
        'run',
        type=_runtype,
        help="Run the program"
    )
    parser.add_argument(  # How far are we going
        'mode',
        type=str,
        choices=('raw', 'final'),
        help="How far through the pipeline should we run? Choose from %(choices)s"
    )
    parser.add_argument(  # Output directory
        '--outdir',
        dest='outdir',
        type=str,
        required=False,
        default=os.getcwd(),
        metavar='outdir',
        help="Path to output directory; defaults to %(default)s"
    )
    parser.add_argument(  # Verbosity
        '--verbosity',
        dest='verbosity',
        type=str,
        required=False,
        default='info',
        choices=LOG_LEVELS.keys(),
        metavar='level',
        help="Verbosity level, choose from %(choices)s; defaults to %(default)s"
    )
    # With no arguments at all, show the help text and exit with status 0
    if not sys.argv[1:]:
        parser.print_help()
        sys.exit(0)
    args = vars(parser.parse_args())  # type: Dict[str, Any]
    logging.basicConfig(
        format=LOG_FORMAT,
        datefmt=DATE_FORMAT,
        level=LOG_LEVELS[args['verbosity']]
    )
    os.makedirs(args['outdir'], exist_ok=True)
    tarball = download(url=URL, outdir=args['outdir'])  # type: str
    adata_save = os.path.join(args['outdir'], 'pbmc3k_%s.h5ad' % args['mode'])  # type: str
    os.makedirs(os.path.dirname(adata_save), exist_ok=True)
    # Extract data
    logging.info("Extracting tarball")
    extract_start = time.time()  # type: float
    # The context manager closes the archive even if extraction fails
    with tarfile.open(tarball) as tarhandle:  # type: tarfile.TarFile
        # NOTE(review): extractall() trusts the member paths stored in the
        # archive, and the archive is fetched over plain HTTP; consider
        # validating members (or passing filter='data' where the running
        # Python supports it)
        tarhandle.extractall(path=args['outdir'])
    logging.debug("Extraction took %s seconds", round(time.time() - extract_start, 3))
    # Load data
    logging.info("Loading data")
    load_start = time.time()  # type: float
    dl_dir = os.path.join(args['outdir'], STRUCT_10X)  # type: str
    adata = scanpy.read_10x_mtx(path=dl_dir, var_names='gene_symbols', cache=True)  # type: anndata._core.anndata.AnnData
    adata.var_names_make_unique()
    logging.debug("Loading data took %s seconds", round(time.time() - load_start, 3))
    # Clean up
    logging.info("Cleaning up downloaded data")
    clean_start = time.time()  # type: float
    os.remove(tarball)
    # Remove the parent of dl_dir, i.e. the top-level extracted directory
    extracted_root = os.path.dirname(dl_dir)  # type: str
    logging.debug("Removing extracted data at %s", extracted_root)
    shutil.rmtree(extracted_root)
    logging.debug("Cleaning downloaded data took %s seconds", round(time.time() - clean_start, 3))
    # In 'raw' mode the data is saved as loaded and nothing else is run
    if args['mode'] == 'raw':
        adata.write(adata_save)
        logging.info("Final H5AD file can be found at %s", adata_save)
        sys.exit(0)
    # Preprocessing
    logging.info("Preprocessing")
    preproc_start = time.time()  # type: float
    scanpy.pp.filter_cells(adata, min_genes=200)
    scanpy.pp.filter_genes(adata, min_cells=3)
    # Genes whose name starts with 'MT-' are treated as mitochondrial
    mito_genes = adata.var_names.str.startswith('MT-')  # type: numpy.ndarray[numpy.bool_]
    # Per-cell fraction of counts coming from those genes, and per-cell total
    adata.obs['percent_mito'] = numpy.sum(adata[:, mito_genes].X, axis=1).A1 / numpy.sum(adata.X, axis=1).A1
    adata.obs['n_counts'] = adata.X.sum(axis=1).A1
    # NOTE(review): 'n_genes' is not set anywhere in this script; presumably
    # filter_cells() adds it to .obs -- confirm against the scanpy docs
    adata = adata[adata.obs.n_genes < 2500, :]  # type: anndata._core.anndata.AnnData
    adata = adata[adata.obs.percent_mito < 0.05, :]  # type: anndata._core.anndata.AnnData
    logging.debug("Preprocessing took %s seconds", round(time.time() - preproc_start, 3))
    # Normalization
    logging.info("Normalizing the data")
    norm_start = time.time()  # type: float
    scanpy.pp.normalize_total(adata, target_sum=1e4)
    scanpy.pp.log1p(adata)
    # Stash the normalized, log-transformed object in .raw before the
    # highly-variable gene subsetting below
    adata.raw = adata
    logging.debug("Normalization took %s seconds", round(time.time() - norm_start, 3))
    # HVF
    logging.info("Finding highly-variable genes")
    hvf_start = time.time()  # type: float
    scanpy.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
    adata = adata[:, adata.var.highly_variable]
    logging.debug("Finding highly-variable genes took %s seconds", round(time.time() - hvf_start, 3))
    # Scaling and regression
    logging.info("Scaling and regressing")
    scale_start = time.time()  # type: float
    scanpy.pp.regress_out(adata, ['n_counts', 'percent_mito'])
    scanpy.pp.scale(adata, max_value=10)
    logging.debug("Scaling and regression took %s seconds", round(time.time() - scale_start, 3))
    # PCA
    logging.info("Running PCA")
    pca_start = time.time()  # type: float
    scanpy.tl.pca(adata, svd_solver='arpack')
    logging.debug("PCA took %s seconds", round(time.time() - pca_start, 3))
    # Nearest neighbors
    logging.info("Finding nearest neighbors")
    nn_start = time.time()  # type: float
    scanpy.pp.neighbors(adata, n_neighbors=10, n_pcs=40)
    logging.debug("Finding nearest neighbors took %s seconds", round(time.time() - nn_start, 3))
    # UMAP
    logging.info("Running UMAP")
    umap_start = time.time()  # type: float
    scanpy.tl.umap(adata)
    logging.debug("Running UMAP took %s seconds", round(time.time() - umap_start, 3))
    # Clustering
    logging.info("Clustering")
    cl_start = time.time()  # type: float
    scanpy.tl.leiden(adata)
    logging.debug("Clustering took %s seconds", round(time.time() - cl_start, 3))
    adata.write(adata_save)
    logging.info("Final H5AD file can be found at %s", adata_save)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment