Skip to content

Instantly share code, notes, and snippets.

@ivirshup
Last active April 14, 2019 06:35
Show Gist options
  • Save ivirshup/3b665d710729cef59b2cc80c0765cea8 to your computer and use it in GitHub Desktop.
Save ivirshup/3b665d710729cef59b2cc80c0765cea8 to your computer and use it in GitHub Desktop.
Weird data projection bug: `UMAP.transform` on new data appears to change the fitted `embedding_`
import scanpy as sc
import numpy as np
import pandas as pd
from scipy import sparse
from umap import UMAP
from itertools import repeat, chain
def preprocess(adata):
    """Library-size normalise and log-transform an AnnData object.

    The scanpy calls below are used for their in-place effect (their return
    values are ignored). The same object is also returned, so callers can
    write ``adata = preprocess(adata)``.

    Parameters
    ----------
    adata
        AnnData to transform; presumably holds raw counts in ``X`` (the
        script calls this on freshly loaded / simulated count data).

    Returns
    -------
    The ``adata`` object that was passed in.
    """
    # NOTE: a mitochondrial QC step (flag genes whose symbol starts with
    # "MT-" and run sc.pp.calculate_qc_metrics) used to live here and is
    # currently disabled.
    pp = sc.pp
    pp.normalize_per_cell(adata, counts_per_cell_after=10000)
    pp.log1p(adata)
    return adata
def pca_update(tgt, src, inplace=True):
    """Project ``tgt`` onto the principal components already fitted on ``src``.

    Parameters
    ----------
    tgt
        AnnData to project. Its variables must match, in number and order,
        the rows of ``src.varm["PCs"]`` (required by the matrix product below).
    src
        AnnData whose ``varm["PCs"]``, of shape (n_vars, n_components), is
        used as the projection basis.
    inplace
        If False, a copy of ``tgt`` is modified and returned, and ``tgt`` itself
        is left untouched.

    Returns
    -------
    The AnnData (``tgt`` or its copy) with ``obsm["X_pca"]`` set to an
    (n_obs, n_components) array.
    """
    # TODO: Make sure we know the settings from src (e.g. whether its PCA was
    # zero-centred at all).
    if not inplace:
        tgt = tgt.copy()
    # Densify or copy, so the in-place centring below never mutates tgt.X.
    if sparse.issparse(tgt.X):
        X = tgt.X.toarray()
    else:
        X = tgt.X.copy()
    # Subtracting a float mean in place from an integer array (e.g. raw counts)
    # raises a casting error, so promote non-floating data first.
    if not np.issubdtype(X.dtype, np.floating):
        X = X.astype(np.float64)
    # NOTE(review): centring uses tgt's own per-variable mean, not src's mean.
    # Confirm this is the intended projection.
    X -= X.mean(axis=0)
    tgt.obsm["X_pca"] = np.dot(X, src.varm["PCs"])
    return tgt
def simulate_doublets(adata, frac=.5, random_state=None):
    """Simulate doublets by summing the counts of random pairs of cells.

    Parameters
    ----------
    adata
        The AnnData object to sample from. ``adata.X`` must hold count data.
    frac
        Number of doublets to simulate, as a fraction of the number of cells
        in ``adata``: ``round(n_obs * frac)`` doublets are produced.
    random_state
        Seed (int) or ``np.random.RandomState`` used to draw the cell pairs.
        ``None`` (the default) draws from the global ``np.random`` state.

    Returns
    -------
    AnnData with one row per simulated doublet.

    - ``X`` is the sum of the two source cells' rows of ``adata.X``.
    - ``obs["src1"]`` and ``obs["src2"]`` hold the integer positions of those
      source cells in ``adata``, sorted so that ``src1 <= src2``.
    - ``var`` is indexed by ``adata.var_names``.

    Notes
    -----
    The two source cells are drawn independently and with replacement, so a
    "doublet" can occasionally be a cell paired with itself (its counts doubled).
    """
    n_obs = adata.X.shape[0]
    n_doublets = int(np.round(n_obs * frac))
    if random_state is None:
        rng = np.random
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)
    # Row k of the selection matrix `pos` gets a 1 at each of its two source cells.
    pos_idx = np.repeat(np.arange(n_doublets), 2)
    combos = rng.randint(0, n_obs, n_doublets * 2)
    # csr_matrix sums duplicate (row, col) entries, so a self-pair gets weight 2.
    pos = sparse.csr_matrix(
        (np.ones_like(combos, dtype=adata.X.dtype), (pos_idx, combos)),
        shape=(n_doublets, n_obs),
    )
    # (n_doublets x n_obs) @ (n_obs x n_vars): each output row sums its sources.
    dblX = pos @ adata.X
    # TODO: Downsample total counts
    srcs = np.sort(combos.reshape(n_doublets, 2), axis=1)
    obs = pd.DataFrame(srcs, columns=["src1", "src2"])
    # AnnData converts non-string obs names to str (with a warning); do it up front.
    obs.index = obs.index.astype(str)
    var = pd.DataFrame(index=adata.var_names)
    return sc.AnnData(dblX, obs=obs, var=var)
# Setup
pbmc = sc.datasets.pbmc3k()
sc.pp.filter_genes(pbmc, min_counts=1)
pbmc.var["gene_symbols"] = pbmc.var.index
pbmc.var.set_index("gene_ids", inplace=True)
# frac=1 yields exactly as many doublets as reference cells. As far as we know,
# the bug below does not happen for frac != 1.
dblt = simulate_doublets(pbmc, 1)
dblt.var["gene_symbols"] = pbmc.var["gene_symbols"]
pbmc = preprocess(pbmc)
dblt = preprocess(dblt)
sc.pp.pca(pbmc)
pca_update(dblt, pbmc)

# Transform (where the bug occurs): UMAP.transform on new data should leave the
# embedding learned by fit_transform unchanged.
# NOTE(review): a suspected cause is umap-learn's layout optimiser also moving
# the reference ("tail") points when the new data has the same number of rows
# as the training data, which is exactly the frac=1 case. Confirm against the
# installed umap-learn version.
umap = UMAP()
pbmc.obsm["X_umap"] = umap.fit_transform(pbmc.obsm["X_pca"])
assert np.allclose(pbmc.obsm["X_umap"], umap.embedding_)
dblt.obsm["X_umap"] = umap.transform(dblt.obsm["X_pca"])
assert np.allclose(pbmc.obsm["X_umap"], umap.embedding_)  # This errors. I'm pretty sure this shouldn't error

# Alternative formulation. Without the .copy() below this "doesn't error" only
# because umap-learn's fit_transform returns umap.embedding_ itself rather than
# a copy (verify for the installed version). pbmc_umap would then alias
# embedding_, and both asserts would compare the array with itself. The snapshot
# lets the post-transform check actually detect a change.
umap = UMAP()
pbmc_umap = umap.fit_transform(pbmc.obsm["X_pca"]).copy()
assert np.allclose(pbmc_umap, umap.embedding_)
dblt_umap = umap.transform(dblt.obsm["X_pca"])
assert np.allclose(pbmc_umap, umap.embedding_)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment