Last active
January 18, 2022 20:01
-
-
Save gibsramen/2b1a6cf30e77b1df348acb01a157e81d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings
from typing import Optional, Tuple

import biom
import click
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from gemelli.rpca import rpca
from skbio.diversity import beta_diversity
from skbio.stats.ordination import pcoa
from skbio.tree import TreeNode
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# Suppress every warning for the whole process, presumably to keep the
# console output of the ordination libraries quiet.
# NOTE(review): this also hides warnings that may signal real problems
# (e.g. deprecations); consider narrowing by category or module.
warnings.filterwarnings("ignore")
def run_pca(table: biom.Table, n_components: int = 3) -> pd.DataFrame:
    """Project samples onto the leading principal components of *table*.

    Each feature is z-scored (``StandardScaler``) before PCA so that every
    feature contributes on the same scale.

    Parameters
    ----------
    table : biom.Table
        Feature table; it is transposed so that samples become rows.
    n_components : int, default 3
        Number of principal components to keep.

    Returns
    -------
    pd.DataFrame
        Sample coordinates indexed by ``table.ids()``, with columns
        ``PC1`` .. ``PC{n_components}``.
    """
    # ``toarray`` yields a plain ndarray. ``todense`` yields ``np.matrix``,
    # which scikit-learn >= 1.2 rejects as estimator input.
    values = table.matrix_data.toarray().T  # rows = samples
    scaled = StandardScaler().fit_transform(values)
    coords = PCA(n_components=n_components).fit_transform(scaled)
    return pd.DataFrame(
        coords,
        index=table.ids(),
        columns=[f"PC{i + 1}" for i in range(n_components)],
    )
def run_pcoa(
    table: biom.Table,
    metric: str = "braycurtis",
    tree: Optional[TreeNode] = None,
    number_of_dimensions: int = 3,
) -> pd.DataFrame:
    """Compute a beta-diversity distance matrix and ordinate it with PCoA.

    Parameters
    ----------
    table : biom.Table
        Feature table; it is transposed so that samples become rows.
    metric : str, default "braycurtis"
        Metric name forwarded to ``skbio.diversity.beta_diversity``.
        Names containing ``"unifrac"`` are treated as phylogenetic: the
        observation IDs and *tree* are passed along as well.
    tree : skbio.TreeNode, optional
        Phylogeny used by UniFrac metrics; ignored for other metrics.
    number_of_dimensions : int, default 3
        Number of PCoA axes to return.

    Returns
    -------
    pd.DataFrame
        PCoA sample coordinates (``OrdinationResults.samples``).

    Raises
    ------
    ValueError
        If a UniFrac metric is requested but *tree* is ``None``.
    """
    # Plain ndarray rather than np.matrix, consistent with ``run_pca``.
    values = table.matrix_data.toarray().T  # rows = samples
    sample_ids = table.ids()
    if "unifrac" in metric:
        # Fail early with a clear message instead of an opaque error from
        # inside skbio when no tree was supplied.
        if tree is None:
            raise ValueError(
                f"Metric '{metric}' is phylogenetic and requires a tree."
            )
        dm = beta_diversity(
            metric,
            values,
            sample_ids,
            otu_ids=table.ids("observation"),
            tree=tree,
        )
    else:
        dm = beta_diversity(metric, values, sample_ids)
    return pcoa(dm, number_of_dimensions=number_of_dimensions).samples
def run_rpca(
    table: biom.Table,
    min_feature_count: int = 2,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Run gemelli's RPCA on *table* and rank feature loadings by magnitude.

    Parameters
    ----------
    table : biom.Table
        Feature table passed straight to ``gemelli.rpca.rpca``.
    min_feature_count : int, default 2
        Forwarded to ``rpca`` as ``min_feature_count``.

    Returns
    -------
    (pd.DataFrame, pd.DataFrame)
        The ordination's sample coordinates, and a copy of its feature
        loadings with an extra ``magnitude`` column. That column holds the
        Euclidean norm of each feature's loading row and is used to choose
        which biplot arrows to draw.
    """
    ordination, _distance = rpca(table, min_feature_count=min_feature_count)
    # Work on a copy so the OrdinationResults object is not mutated as a
    # hidden side effect.
    features = ordination.features.copy()
    # Norm is taken over the loading columns only; "magnitude" is added after.
    features["magnitude"] = np.linalg.norm(features.to_numpy(), axis=1)
    return ordination.samples, features
@click.command()
@click.option("--table", required=True,
              help="Path to the BIOM feature table.")
@click.option("--metadata", required=True,
              help="Path to the tab-separated sample metadata file.")
@click.option("--column", required=True,
              help="Metadata column used to color the samples.")
@click.option("--tree", required=False,
              help="Tree file readable by skbio TreeNode.read "
                   "(needed for UniFrac metrics).")
@click.option("--beta-metric-horseshoe", default="jaccard", required=False,
              show_default=True,
              help="Beta-diversity metric for the horseshoe panel.")
@click.option("--beta-metric-pcoa", default="braycurtis", required=False,
              show_default=True,
              help="Beta-diversity metric for the PCoA panel.")
@click.option("--rpca-arrows", default=5, type=click.IntRange(min=0),
              required=False, show_default=True,
              help="Number of largest-magnitude feature arrows to draw.")
@click.option("--rpca-arrows-scale", default=1.0, type=float,
              required=False, show_default=True,
              help="Multiplier applied to the RPCA arrow lengths.")
@click.option("--output", required=True,
              help="Path of the figure file to write.")
def generate_panels(
    table,
    metadata,
    column,
    tree,
    beta_metric_horseshoe,
    beta_metric_pcoa,
    rpca_arrows,
    rpca_arrows_scale,
    output
):
    """Generate panels for DR review.

    Draws a 2x2 figure, one dimensionality-reduction method per panel,
    with samples colored by the metadata COLUMN:

    \b
    (1) Spikes      - PCA on standardized feature abundances
    (2) Horseshoe   - PCoA using --beta-metric-horseshoe
    (3) PCoA        - PCoA using --beta-metric-pcoa
    (4) RPCA Biplot - RPCA samples plus the top --rpca-arrows feature arrows
    """
    md = pd.read_table(metadata, sep="\t", index_col=0,
                       na_values=["Missing: Not provided"])
    if column not in md.columns:
        raise click.BadParameter(
            f"'{column}' is not a column of {metadata}.",
            param_hint="--column",
        )
    md = md.dropna(subset=[column])
    md.index = md.index.astype(str)

    tbl = biom.load_table(table)
    if tree is not None:
        print("Loading tree...")
        tree = TreeNode.read(tree)
        print("Tree loaded!")

    # Keep only samples present in both inputs. Use a sorted list, not a
    # set: pandas >= 2.0 rejects sets as .loc indexers, and sorting makes
    # the row order deterministic.
    shared_ids = sorted(set(md.index).intersection(tbl.ids().astype(str)))
    if not shared_ids:
        raise click.ClickException(
            "No samples are shared between the table and the metadata "
            f"(after dropping rows with a missing '{column}')."
        )
    md = md.loc[shared_ids]
    tbl.filter(shared_ids)  # in place

    fig, axs = plt.subplots(2, 2, dpi=300, facecolor="white")
    scatter_kwargs = {"x": "PC1", "y": "PC2", "hue": column, "legend": False,
                      "s": 20, "edgecolor": None}

    # (1) PCA
    print("Running PCA...")
    pca_data = run_pca(tbl).join(md)
    sns.scatterplot(data=pca_data, **scatter_kwargs, ax=axs[0, 0])
    axs[0, 0].set_title("PCA")

    # (2) Horseshoe
    print("Running horseshoe...")
    horseshoe_data = run_pcoa(tbl, metric=beta_metric_horseshoe,
                              tree=tree).join(md)
    sns.scatterplot(data=horseshoe_data, **scatter_kwargs, ax=axs[0, 1])
    axs[0, 1].set_title(f"PCoA ({beta_metric_horseshoe})")

    # (3) PCoA
    print("Running PCoA...")
    pcoa_data = run_pcoa(tbl, metric=beta_metric_pcoa, tree=tree).join(md)
    sns.scatterplot(data=pcoa_data, **scatter_kwargs, ax=axs[1, 0])
    axs[1, 0].set_title(f"PCoA ({beta_metric_pcoa})")

    # (4) RPCA biplot
    print("Running RPCA...")
    rpca_sample_data, rpca_feature_data = run_rpca(tbl)
    rpca_sample_data = rpca_sample_data.join(md)
    # head() tolerates asking for more arrows than there are features.
    top_features = rpca_feature_data.sort_values(
        by="magnitude",
        ascending=False
    ).head(rpca_arrows)
    sns.scatterplot(data=rpca_sample_data, **scatter_kwargs, ax=axs[1, 1])
    for _, loading in top_features.iterrows():
        axs[1, 1].arrow(
            x=0,
            y=0,
            dx=loading["PC1"] * rpca_arrows_scale,
            dy=loading["PC2"] * rpca_arrows_scale
        )
    axs[1, 1].set_title("RPCA")

    # Axis ticks/labels are hidden on every panel.
    for ax in axs.flat:
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    fig.savefig(output)
    plt.close(fig)
    print("Saved!")
if __name__ == "__main__":
    # Script entry point: generate_panels is a click command, so calling it
    # with no arguments makes click read the options from the command line.
    generate_panels()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment