# Gist by @gibsramen (last active January 18, 2022):
# https://gist.github.com/gibsramen/2b1a6cf30e77b1df348acb01a157e81d
from typing import Optional, Tuple
import warnings

import biom
import click
from gemelli.rpca import rpca
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from skbio.diversity import beta_diversity
from skbio.stats.ordination import pcoa
from skbio.tree import TreeNode
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# Suppress every warning category from every module for the whole process,
# presumably to keep console output down to the progress messages below.
# NOTE(review): this also hides library deprecation notices (e.g. input
# types a dependency is about to stop accepting) -- consider narrowing.
warnings.simplefilter("ignore")
def run_pca(table: biom.Table, n_components: int = 3) -> pd.DataFrame:
    """Run PCA on z-scored features of a BIOM table's samples.

    Each feature is standardized (``StandardScaler``) before PCA so that
    features contribute on a common scale.

    Parameters
    ----------
    table : biom.Table
        Feature table. It is transposed so each sample becomes one row of
        the PCA input.
    n_components : int, optional
        Number of principal components to keep (default 3).

    Returns
    -------
    pd.DataFrame
        Sample coordinates indexed by ``table.ids()``, with columns
        ``PC1`` .. ``PC{n_components}``.
    """
    # toarray() yields a plain ndarray. todense() returns np.matrix, which
    # scikit-learn deprecated in 1.0 and rejects with TypeError from 1.2.
    values = table.matrix_data.toarray().T
    scaled = StandardScaler().fit_transform(values)
    coords = PCA(n_components=n_components).fit_transform(scaled)
    return pd.DataFrame(
        coords,
        index=table.ids(),
        columns=[f"PC{i + 1}" for i in range(n_components)],
    )
def run_pcoa(
    table: biom.Table,
    metric: str = "braycurtis",
    tree: Optional[TreeNode] = None,
    number_of_dimensions: int = 3,
) -> pd.DataFrame:
    """Run PCoA on a beta-diversity distance matrix of the table's samples.

    Parameters
    ----------
    table : biom.Table
        Feature table. It is transposed so each sample is one row of the
        count matrix given to ``skbio.diversity.beta_diversity``.
    metric : str, optional
        scikit-bio beta-diversity metric name. Any name containing
        ``"unifrac"`` is treated as phylogenetic and needs ``tree``.
    tree : skbio.TreeNode, optional
        Phylogeny used only for UniFrac metrics. Presumably it must contain
        every observation ID of ``table`` as a tip -- confirm against the
        scikit-bio version in use.
    number_of_dimensions : int, optional
        Number of PCoA axes to compute (default 3).

    Returns
    -------
    pd.DataFrame
        ``OrdinationResults.samples``: per-sample coordinates indexed by
        sample ID (columns named ``PC1``, ``PC2``, ...).

    Raises
    ------
    ValueError
        If a UniFrac metric is requested without a tree.
    """
    # Plain ndarray, consistent with run_pca (todense() gives np.matrix).
    values = table.matrix_data.toarray().T
    sample_ids = table.ids()
    if "unifrac" in metric:
        # Fail early with a clear message instead of inside scikit-bio.
        if tree is None:
            raise ValueError(
                f"Metric '{metric}' is phylogenetic and requires a tree."
            )
        # UniFrac needs the observation IDs to place features on the tree.
        dm = beta_diversity(
            metric,
            values,
            sample_ids,
            otu_ids=table.ids("observation"),
            tree=tree,
        )
    else:
        dm = beta_diversity(metric, values, sample_ids)
    return pcoa(dm, number_of_dimensions=number_of_dimensions).samples
def run_rpca(
    table: biom.Table,
    min_feature_count: int = 2,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Run gemelli RPCA and score each feature's loading magnitude.

    Parameters
    ----------
    table : biom.Table
        Feature table passed straight to ``gemelli.rpca.rpca``.
    min_feature_count : int, optional
        Forwarded to ``rpca`` as its feature-count filter (default 2, the
        value previously hard-coded).

    Returns
    -------
    (samples, features) : tuple of pd.DataFrame
        ``samples`` is the RPCA sample ordination. ``features`` is a copy of
        the feature loadings with an extra ``magnitude`` column: the
        Euclidean norm of each feature's row across all loading axes (not
        only the two that get plotted).
    """
    # The distance matrix rpca also returns is not needed here.
    ordination, _ = rpca(table, min_feature_count=min_feature_count)
    features = ordination.features.copy()
    # One vectorized row-wise norm instead of a Python-level apply().
    features["magnitude"] = np.linalg.norm(features.to_numpy(), axis=1)
    return ordination.samples, features
@click.command()
@click.option("--table", required=True)
@click.option("--metadata", required=True)
@click.option("--column", required=True)
@click.option("--tree", required=False)
@click.option("--beta-metric-horseshoe", default="jaccard", required=False)
@click.option("--beta-metric-pcoa", default="braycurtis", required=False)
# Types are explicit because Click otherwise infers them from the default:
# an int default of 1 made --rpca-arrows-scale reject values such as 0.5.
@click.option("--rpca-arrows", type=int, default=5, required=False)
@click.option("--rpca-arrows-scale", type=float, default=1.0, required=False)
@click.option("--output", required=True)
def generate_panels(
    table,
    metadata,
    column,
    tree,
    beta_metric_horseshoe,
    beta_metric_pcoa,
    rpca_arrows,
    rpca_arrows_scale,
    output
):
    """Generate panels for DR review.

    (1) Spikes (PCA on z-scored features)
    (2) Horseshoe (PCoA with --beta-metric-horseshoe)
    (3) PCoA (with --beta-metric-pcoa)
    (4) RPCA Biplot (top --rpca-arrows features drawn as arrows)
    """
    # Fail fast, before any expensive loading or ordination, when a
    # phylogenetic metric is requested without a tree.
    for metric in (beta_metric_horseshoe, beta_metric_pcoa):
        if "unifrac" in metric and tree is None:
            raise click.UsageError(f"--tree is required for metric '{metric}'.")

    # Samples without a value in the coloring column are dropped.
    md = pd.read_table(metadata, sep="\t", index_col=0,
                       na_values=["Missing: Not provided"])
    md = md.dropna(subset=[column])
    tbl = biom.load_table(table)
    if tree is not None:
        print("Loading tree...")
        tree = TreeNode.read(tree)
        print("Tree loaded!")

    # Keep only samples present in both inputs. A sorted list (not a set)
    # is used because pandas >= 2.0 refuses a set as a .loc indexer.
    tbl_ids = tbl.ids().astype(str)
    md.index = md.index.astype(str)
    samps_to_keep = sorted(set(md.index).intersection(tbl_ids))
    if not samps_to_keep:
        raise click.ClickException(
            "No sample IDs are shared between the table and the metadata "
            f"(after dropping rows with no value in '{column}')."
        )
    md = md.loc[samps_to_keep]
    # The return value is unused: this relies on biom's in-place default.
    tbl.filter(samps_to_keep)

    fig, axs = plt.subplots(2, 2, dpi=300, facecolor="white")
    args = {"x": "PC1", "y": "PC2", "hue": column, "legend": False,
            "s": 20, "edgecolor": None}

    # PCA
    print("Running PCA...")
    pca_data = run_pca(tbl).join(md)
    sns.scatterplot(data=pca_data, **args, ax=axs[0, 0])
    axs[0, 0].set_title("PCA")

    # Horseshoe
    print("Running horseshoe...")
    horseshoe_data = run_pcoa(tbl, metric=beta_metric_horseshoe,
                              tree=tree).join(md)
    sns.scatterplot(data=horseshoe_data, **args, ax=axs[0, 1])
    axs[0, 1].set_title(f"PCoA ({beta_metric_horseshoe})")

    # PCoA
    print("Running PCoA...")
    pcoa_data = run_pcoa(tbl, metric=beta_metric_pcoa, tree=tree).join(md)
    sns.scatterplot(data=pcoa_data, **args, ax=axs[1, 0])
    axs[1, 0].set_title(f"PCoA ({beta_metric_pcoa})")

    # RPCA
    print("Running RPCA...")
    rpca_sample_data, rpca_feature_data = run_rpca(tbl)
    rpca_sample_data = rpca_sample_data.join(md)
    # head() caps the arrow count at the number of features available
    # (indexing with iloc[i] raised IndexError past that). max(..., 0)
    # keeps a negative count drawing nothing, as range() did.
    top_features = (
        rpca_feature_data
        .sort_values(by="magnitude", ascending=False)
        .head(max(rpca_arrows, 0))
    )
    sns.scatterplot(data=rpca_sample_data, **args, ax=axs[1, 1])
    for pc1, pc2 in zip(top_features["PC1"], top_features["PC2"]):
        axs[1, 1].arrow(
            x=0,
            y=0,
            dx=pc1 * rpca_arrows_scale,
            dy=pc2 * rpca_arrows_scale,
        )
    axs[1, 1].set_title("RPCA")

    for ax in axs.flatten():
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
    fig.savefig(output)
    # Free the figure once it has been written.
    plt.close(fig)
    print("Saved!")
if __name__ == "__main__":
    # Script entry point: hand control to the Click command, which parses
    # the command-line options declared on generate_panels.
    generate_panels()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment