Created
December 5, 2023 15:18
-
-
Save sergpolly/ee39a452c1e30f12d5100b28f35f4ee0 to your computer and use it in GitHub Desktop.
flexible saddles by distance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings
from functools import partial
from itertools import combinations

import bioframe
import numpy as np
import pandas as pd
from scipy.linalg import toeplitz

from cooltools.api.saddle import (
    _accumulate,
    _make_cis_obsexp_fetcher,
    _make_trans_obsexp_fetcher,
    digitize,
)
from cooltools.lib.checks import (
    is_compatible_viewframe,
    is_cooler_balanced,
    is_track,
    is_valid_expected,
)
from cooltools.lib.common import align_track_with_cooler, view_from_track
# S and C are 3-dimensional arrays of shape (max_dist, n, n):
# S[d] is the usual n x n saddle of summed values, restricted to pixels that are d bins apart;
# C[d] is the matching n x n matrix with the number of pixels that contributed to S[d].
def _accumulate_dist(
    S, C, getmatrix, digitized, reg1, reg2, max_dist, verbose=False
):
    """
    Accumulate a distance-resolved saddle for one cis region, in place.

    Every finite pixel of the matrix returned by ``getmatrix(reg1, reg2)``
    that lies fewer than ``max_dist`` bins away from the main diagonal is
    added to ``S[d, i, j]`` (its value) and ``C[d, i, j]`` (a count of 1),
    where ``d = |row - col|`` and ``i`` / ``j`` are the digitized track
    values of the pixel's row and column.

    Parameters
    ----------
    S, C : 3D arrays
        Running sums and counts, modified in place. The first axis is the
        distance axis and is expected to have length ``max_dist``; the size
        of the last axis is the number of track categories iterated over.
    getmatrix : callable
        ``getmatrix(reg1, reg2)`` returns the 2D matrix for the region.
    digitized : mapping
        Region name -> 1D digitized track values, one per matrix row/column.
    reg1, reg2 : hashable
        Region names; they must be equal (cis only).
    max_dist : int
        Number of distances (0 .. max_dist - 1) that are resolved.
        Pixels farther from the diagonal are ignored.
    verbose : bool, optional
        If True, print the region pair being processed.

    Raises
    ------
    ValueError
        If ``reg1`` and ``reg2`` differ.
    """
    if reg1 != reg2:
        raise ValueError("this is special version of accumulate for cis data only ...")

    n_bins = S.shape[-1]
    matrix = getmatrix(reg1, reg2)

    if verbose:
        print("regions {} vs {}".format(reg1, reg2))

    # Toeplitz-like matrix holding the distance (in bins) of every pixel
    offsets = np.arange(matrix.shape[0])
    dist_mat = np.abs(offsets[None, :] - offsets[:, None])

    # one boolean mask per track category, computed once instead of n_bins**2 times
    track_values = digitized[reg1]
    masks = [np.asarray(track_values == k) for k in range(n_bins)]

    for i, row_mask in enumerate(masks):
        row_data = matrix[row_mask, :]
        row_dist = dist_mat[row_mask, :]
        for j, col_mask in enumerate(masks):
            data = row_data[:, col_mask]
            dist = row_dist[:, col_mask]
            # np.bincount returns max(dist) + 1 entries when that exceeds
            # minlength, which would not fit the distance axis of S and C,
            # so pixels at max_dist or beyond have to be dropped explicitly.
            keep = np.isfinite(data) & (dist < max_dist)
            data = data[keep]
            dist = dist[keep]
            # sums unrolled by distance
            S[:, i, j] += np.bincount(dist, weights=data, minlength=max_dist)
            # counts unrolled by distance
            C[:, i, j] += np.bincount(dist, minlength=max_dist).astype(float)
def saddle_dist(
    clr,
    expected,
    track,
    contact_type,
    n_bins,
    vrange=None,
    qrange=None,
    view_df=None,
    clr_weight_name="weight",
    expected_value_col="balanced.avg",
    view_name_col="name",
    max_dist=100,
    trim_outliers=False,
    verbose=False,
    drop_track_na=False,
):
    """
    Get matrices of summed observed/expected interactions between genomic bin
    pairs as a function of a specified genomic track; for cis contacts the
    result is additionally resolved by genomic distance.

    The provided genomic track is either:
    (a) digitized inside this function by passing 'n_bins', and one of
        'vrange' or 'qrange'
    (b) passed as a pre-digitized track with a categorical value column
        (pass n_bins=None in that case).

    Parameters
    ----------
    clr : cooler.Cooler
        Observed matrix.
    expected : DataFrame in expected format
        Diagonal summary statistics for each chromosome, and name of the column
        with the values of expected to use.
    track : DataFrame
        A track, i.e. BedGraph-like dataframe, which is digitized with
        the options n_bins, vrange and qrange. Can optionally be passed
        as a pre-digitized dataFrame with a categorical value column
        (4th column), also passing n_bins as None.
    contact_type : str
        If 'cis' then only cis interactions are used to build the matrix.
        If 'trans', only trans interactions are used.
    n_bins : int or None
        number of bins for signal quantization. If None, then track must
        be passed as a pre-digitized track.
    vrange : tuple
        Low and high values used for binning track values.
    qrange : tuple
        Low and high values for quantile binning track values.
        Low must be 0.0 or more, high must be 1.0 or less.
        Only one of vrange or qrange can be passed.
    view_df: viewframe
        Viewframe with genomic regions. If none, generate from track chromosomes.
    clr_weight_name : str
        Name of the column in the clr.bins to use as balancing weights.
        Using raw unbalanced data is not supported for saddles.
    expected_value_col : str
        Name of the column in expected used for normalizing.
    view_name_col : str
        Name of column in view_df with region names.
    max_dist : int
        Size of the distance axis of the cis output; index ``d`` along that
        axis corresponds to bin pairs separated by ``d`` bins.
        Ignored with contact_type=trans.
    trim_outliers : bool, optional
        Remove first and last row and column from the output saddles.
    verbose : bool, optional
        If True then reports progress.
    drop_track_na : bool, optional
        If True then drops NaNs in input track (as if they were missing),
        If False then counts NaNs as present in dataframe.
        In general, this only adds check for chromosomes that have all
        missing values, but does not affect the results.

    Returns
    -------
    interaction_sum : array
        Summed interaction values between two genomic bins given their values
        of the provided genomic track. For 'cis' a 3D array of shape
        (max_dist, n_bins + 2, n_bins + 2); for 'trans' a 2D array of shape
        (n_bins + 2, n_bins + 2). With ``trim_outliers`` the two saddle axes
        have length n_bins instead.
    interaction_count : array
        Same shape as ``interaction_sum``; the number of genomic bin pairs
        that contributed to the corresponding entry of ``interaction_sum``.

    Raises
    ------
    ValueError
        If contact_type or n_bins is invalid, if a pre-digitized track is
        not categorical, or if view_df / expected / cooler balancing do not
        pass the compatibility checks.
    """
    # fail fast, before any expensive alignment / digitization is done
    if contact_type not in ("cis", "trans"):
        raise ValueError("Allowed values for contact_type are 'cis' or 'trans'.")

    if type(n_bins) is int:
        # perform digitization
        track = align_track_with_cooler(
            track,
            clr,
            view_df=view_df,
            clr_weight_name=clr_weight_name,
            mask_clr_bad_bins=True,
            drop_track_na=drop_track_na,  # this adds check for chromosomes that have all missing values
        )
        digitized_track, _ = digitize(
            track.iloc[:, :4],
            n_bins,
            vrange=vrange,
            qrange=qrange,
            digitized_suffix=".d",
        )
        digitized_col = digitized_track.columns[3]
    elif n_bins is None:
        # assume and test if track is pre-digitized
        digitized_track = track
        digitized_col = digitized_track.columns[3]
        is_track(track.astype({digitized_col: "float"}), raise_errors=True)
        # .iloc: the value column is addressed by position, not by label
        if not isinstance(digitized_track.dtypes.iloc[3], pd.CategoricalDtype):
            raise ValueError(
                "when n_bins=None, saddle assumes the track has been "
                + "pre-digitized and the value column is a "
                + "pandas categorical. See get_digitized()."
            )
        cats = digitized_track[digitized_col].dtype.categories.values
        # cats has two additional categories, 0 and n_bins+1, for values
        # falling outside range, as well as -1 for NAs.
        n_bins = len(cats[cats > -1]) - 2
    else:
        raise ValueError("n_bins must be provided as int or None")

    if view_df is None:
        view_df = view_from_track(digitized_track)
    else:
        # Make sure view_df is a proper viewframe
        try:
            is_compatible_viewframe(
                view_df,
                clr,
                check_sorting=True,  # just in case
                raise_errors=True,
            )
        except Exception as e:
            raise ValueError("view_df is not a valid viewframe or incompatible") from e

    # make sure provided expected is compatible
    try:
        is_valid_expected(
            expected,
            contact_type,
            view_df,
            verify_cooler=clr,
            expected_value_cols=[
                expected_value_col,
            ],
            raise_errors=True,
        )
    except Exception as e:
        raise ValueError("provided expected is not compatible") from e

    # check if cooler is balanced
    if clr_weight_name:
        try:
            is_cooler_balanced(clr, clr_weight_name, raise_errors=True)
        except Exception as e:
            raise ValueError(
                f"provided cooler is not balanced or {clr_weight_name} is missing"
            ) from e

    # digitized track values per view region, keyed by region name
    digitized_tracks = {
        reg[view_name_col]: bioframe.select(digitized_track, reg)[digitized_col]
        for _, reg in view_df.iterrows()
    }

    if contact_type == "cis":
        # only symmetric intra-chromosomal regions :
        supports = list(zip(view_df[view_name_col], view_df[view_name_col]))

        getmatrix = _make_cis_obsexp_fetcher(
            clr,
            expected,
            view_df,
            view_name_col=view_name_col,
            expected_value_col=expected_value_col,
            clr_weight_name=clr_weight_name,
        )

        # the saddle axes include 2 open bins for values <lo and >hi.
        interaction_sum = np.zeros((max_dist, n_bins + 2, n_bins + 2))
        interaction_count = np.zeros((max_dist, n_bins + 2, n_bins + 2))

        for reg1, reg2 in supports:
            _accumulate_dist(
                interaction_sum,
                interaction_count,
                getmatrix,
                digitized_tracks,
                reg1,
                reg2,
                max_dist,
                verbose=verbose,
            )

        if trim_outliers:
            # distance axis is kept intact, only the saddle axes are trimmed
            interaction_sum = interaction_sum[:, 1:-1, 1:-1]
            interaction_count = interaction_count[:, 1:-1, 1:-1]

        return interaction_sum, interaction_count

    # contact_type == "trans" (validated at the top)
    # asymmetric inter-chromosomal regions : keep pairs on different chromosomes
    chrom_of = dict(zip(view_df[view_name_col], view_df["chrom"]))
    supports = [
        (name1, name2)
        for name1, name2 in combinations(view_df[view_name_col], 2)
        if chrom_of[name1] != chrom_of[name2]
    ]

    getmatrix = _make_trans_obsexp_fetcher(
        clr,
        expected,
        view_df,
        view_name_col=view_name_col,
        expected_value_col=expected_value_col,
        clr_weight_name=clr_weight_name,
    )

    # the saddle axes include 2 open bins for values <lo and >hi.
    interaction_sum = np.zeros((n_bins + 2, n_bins + 2))
    interaction_count = np.zeros((n_bins + 2, n_bins + 2))

    for reg1, reg2 in supports:
        # NOTE(review): diagonal limits are not passed - this function has no
        # min_diag/max_diag parameters, and they are presumably only relevant
        # for identical regions in cooltools' _accumulate; confirm upstream.
        _accumulate(
            interaction_sum,
            interaction_count,
            getmatrix,
            digitized_tracks,
            reg1,
            reg2,
            verbose=verbose,
        )

    # each region pair is visited once, so symmetrize the result
    interaction_sum += interaction_sum.T
    interaction_count += interaction_count.T

    if trim_outliers:
        # trans saddles are 2D - there is no distance axis to skip
        interaction_sum = interaction_sum[1:-1, 1:-1]
        interaction_count = interaction_count[1:-1, 1:-1]

    return interaction_sum, interaction_count
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment