Last active
April 14, 2019 06:35
-
-
Save ivirshup/3b665d710729cef59b2cc80c0765cea8 to your computer and use it in GitHub Desktop.
Weird data projection bug
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scanpy as sc | |
import numpy as np | |
import pandas as pd | |
from scipy import sparse | |
from umap import UMAP | |
from itertools import repeat, chain | |
def preprocess(adata):
    """Normalize per-cell counts and log-transform ``adata``.

    Both scanpy calls are made without capturing a return value, so this
    function relies on them modifying ``adata`` in place. The object that was
    passed in is returned so callers can write ``x = preprocess(x)``.

    Params
    ------
    adata
        The anndata object to preprocess.

    Returns
    -------
    The same ``adata`` object that was passed in.
    """
    # QC annotation is currently disabled:
    # adata.var["mito"] = adata.var["gene_symbols"].str.startswith("MT-")
    # sc.pp.calculate_qc_metrics(adata, qc_vars=["mito"], inplace=True)
    sc.pp.normalize_per_cell(adata, counts_per_cell_after=10000)
    sc.pp.log1p(adata)
    return adata
def pca_update(tgt, src, inplace=True):
    """Project ``tgt`` onto the principal components stored in ``src``.

    The target matrix is densified, centred by subtracting its *own*
    per-column mean (not the mean of ``src``), and multiplied by
    ``src.varm["PCs"]``. The result is stored in ``tgt.obsm["X_pca"]``.

    Params
    ------
    tgt
        Object whose ``X`` (dense or scipy-sparse) is projected.
    src
        Object providing the loadings in ``varm["PCs"]``; its number of rows
        must match the number of columns of ``tgt.X``.
    inplace
        If false, operate on ``tgt.copy()`` instead of ``tgt`` itself.

    Returns
    -------
    The object that received ``obsm["X_pca"]`` (``tgt`` or its copy).
    """
    # TODO: Make sure we know the settings from src
    if not inplace:
        tgt = tgt.copy()
    # Work on a dense copy so that centring never modifies tgt.X.
    if sparse.issparse(tgt.X):
        X = tgt.X.toarray()
    else:
        X = tgt.X.copy()
    # Integer data (e.g. raw counts) cannot take an in-place subtraction of a
    # floating-point mean, so promote the working copy to float first.
    if not np.issubdtype(X.dtype, np.inexact):
        X = X.astype(np.float64)
    X -= np.asarray(tgt.X.mean(axis=0))
    tgt.obsm["X_pca"] = np.dot(X, src.varm["PCs"])
    return tgt
def simulate_doublets(adata, frac=.5):
    """Simulate doublets from count data.

    Each simulated doublet is the sum of the ``X`` rows of two cells whose
    positions are drawn independently with ``np.random.randint``; nothing
    prevents a cell from being paired with itself.

    Params
    ------
    adata
        The anndata object to sample from. Must have count data.
    frac
        Number of doublets to simulate, expressed as a fraction of the
        number of cells in ``adata`` (rounded to the nearest integer).

    Returns
    -------
    A new AnnData with one row per simulated doublet. ``obs`` has columns
    ``src1`` and ``src2`` holding the positional indices of the two source
    cells (sorted within each pair); ``var`` is indexed by
    ``adata.var_names``.

    Raises
    ------
    ValueError
        If ``frac`` yields a negative number of doublets.
    """
    m = adata.X.shape[0]
    n_doublets = int(np.round(m * frac))
    if n_doublets < 0:
        raise ValueError(
            f"frac must give a non-negative number of doublets, got frac={frac!r}"
        )
    # Row index of the selection matrix: doublet i owns entries 2*i and 2*i + 1.
    pos_idx = np.repeat(np.arange(n_doublets), 2)
    # Column index: the source cell for each of those entries.
    combos = np.random.randint(0, m, n_doublets * 2)
    pos = sparse.csr_matrix(
        (np.ones_like(combos, dtype=adata.X.dtype), (pos_idx, combos)),
        shape=(n_doublets, m),
    )
    # (n_doublets x m) selection matrix times (m x genes) counts.
    dblX = pos * adata.X
    # TODO: Downsample total counts
    srcs = np.sort(combos.reshape(n_doublets, 2), axis=1)
    obs = pd.DataFrame(srcs, columns=["src1", "src2"])
    var = pd.DataFrame(index=adata.var_names)
    return sc.AnnData(dblX, obs=obs, var=var)
# Setup
pbmc = sc.datasets.pbmc3k()
sc.pp.filter_genes(pbmc, min_counts=1)
# Keep the original index as a column, then re-index var by "gene_ids".
pbmc.var["gene_symbols"] = pbmc.var.index
pbmc.var.set_index("gene_ids", inplace=True)

dblt = simulate_doublets(pbmc, 1)  # Does not happen for `frac!=1` afaik
dblt.var["gene_symbols"] = pbmc.var["gene_symbols"]

pbmc = preprocess(pbmc)
dblt = preprocess(dblt)

sc.pp.pca(pbmc)
# Called with the default inplace=True, so dblt.obsm["X_pca"] is set on dblt.
pca_update(dblt, pbmc)

# Transform (where the bug occurs):
umap = UMAP()
pbmc.obsm["X_umap"] = umap.fit_transform(pbmc.obsm["X_pca"])
assert np.allclose(pbmc.obsm["X_umap"], umap.embedding_)
dblt.obsm["X_umap"] = umap.transform(dblt.obsm["X_pca"])
# This errors. I'm pretty sure this shouldn't error: transforming new data
# should not change the embedding fitted on pbmc.
assert np.allclose(pbmc.obsm["X_umap"], umap.embedding_)

# Alternative formulation, which doesn't error and looks fine:
# NOTE(review): if fit_transform returns the very array stored in
# umap.embedding_, the asserts below compare an array with itself and would
# pass even if transform modified it -- confirm before treating this as
# evidence that the embedding is unchanged.
umap = UMAP()
pbmc_umap = umap.fit_transform(pbmc.obsm["X_pca"])
assert np.allclose(pbmc_umap, umap.embedding_)
dblt_umap = umap.transform(dblt.obsm["X_pca"])
assert np.allclose(pbmc_umap, umap.embedding_)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment