Skip to content

Instantly share code, notes, and snippets.

@timothymillar
Created December 10, 2021 04:01
Show Gist options
  • Save timothymillar/765adee5a4f6c14e7d218db4500caeb1 to your computer and use it in GitHub Desktop.
Test cohort reductions for sgkit
import numpy as np
import dask.array as da
from numba import guvectorize
from sgkit.typing import ArrayLike
import sgkit as sg
from typing import Any, Callable, Hashable, Iterable, Optional, Tuple, Union
def cohort_reduction(statistic: Callable) -> Callable:
    # Decorator: wraps a guvectorize kernel with gufunc signature
    # "(n),(n),(c)->(c)" so that it can be applied blockwise to a dask array,
    # reducing the samples axis down to one entry per cohort.
    def cohort_statistic(values: ArrayLike, cohorts: ArrayLike, sample_axis: int = 1) -> ArrayLike:
        """Apply a reduction operation to the samples of each cohort.

        Parameters
        ----------
        values
            N-dimensional array with a samples dimension.
        cohorts
            Array of integers indicating the cohort each sample is assigned to.
        sample_axis
            Integer indicating the samples dimension in the values array (default = 1).

        Returns
        -------
        Array of results for each cohort. This is the same shape as the values array
        with the exception that the samples axis now has length equal to the number of
        cohorts.
        """
        values = da.array(values)
        cohorts = np.array(cohorts)
        assert cohorts.ndim == 1
        # Cohort labels are assumed to be 0..max; negative labels mean
        # "unassigned" and are skipped inside the kernels.
        n_cohorts = cohorts.max() + 1
        # Output shape/chunking: identical to `values` except the samples axis
        # is replaced by a single chunk of length n_cohorts.
        shape = list(values.shape)
        chunks = list(values.chunks)
        shape[sample_axis] = n_cohorts
        chunks[sample_axis] = (n_cohorts, )
        shape = tuple(shape)
        chunks = tuple(chunks)
        # Dummy array supplying the "(c)"-sized input operand required by the
        # gufunc signature (some kernels also reuse it as scratch space).
        dummy = da.zeros(shape, chunks=chunks, dtype=np.uint64)
        out = da.map_blocks(
            statistic,
            values,
            cohorts,
            dummy,
            chunks=chunks,
            # Replace the samples axis with the new cohorts axis.
            drop_axis=sample_axis,
            new_axis=sample_axis,
            dtype=np.float64,
            # Forwarded to the gufunc: maps each operand's core dimension onto
            # the samples axis. NOTE(review): assumes `cohorts` (1-d, axis 0)
            # lines up with blocks along the samples axis — confirm behaviour
            # when the samples axis is chunked.
            axes=[sample_axis, 0, sample_axis, sample_axis],
        )
        return out
    # Expose the wrapped kernel's name on the returned closure.
    cohort_statistic.__name__ = statistic.__name__
    return cohort_statistic
@cohort_reduction
@guvectorize(  # type: ignore
    [
        "void(int32[:], int64[:], uint64[:], int64[:])",
        "void(int64[:], int64[:], uint64[:], int64[:])",
        "void(float32[:], int64[:], uint64[:], float64[:])",
        "void(float64[:], int64[:], uint64[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
    nopython=True,
    cache=False,
)
def cohort_sum(
    values: ArrayLike, cohorts: ArrayLike, _: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    # Accumulate each sample's value into its cohort's total.
    # Negative cohort labels mark unassigned samples and are skipped.
    out[:] = 0
    for idx in range(len(values)):
        cohort = cohorts[idx]
        if cohort >= 0:
            out[cohort] += values[idx]
@cohort_reduction
@guvectorize(  # type: ignore
    [
        "void(int32[:], int64[:], uint64[:], int64[:])",
        "void(int64[:], int64[:], uint64[:], int64[:])",
        "void(float32[:], int64[:], uint64[:], float64[:])",
        "void(float64[:], int64[:], uint64[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
    nopython=True,
    cache=False,
)
def cohort_nansum(
    values: ArrayLike, cohorts: ArrayLike, _: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    # Sum each cohort's sample values, ignoring NaNs.
    # Negative cohort labels mark unassigned samples and are skipped.
    out[:] = 0
    for idx in range(len(values)):
        val = values[idx]
        if np.isnan(val):
            continue
        cohort = cohorts[idx]
        if cohort >= 0:
            out[cohort] += val
@cohort_reduction
@guvectorize(  # type: ignore
    [
        "void(float32[:], int64[:], uint64[:], float64[:])",
        "void(float64[:], int64[:], uint64[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
    nopython=True,
    cache=False,
)
def cohort_mean(
    values: ArrayLike, cohorts: ArrayLike, counts: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    # Mean of each cohort's sample values; an empty cohort yields NaN.
    # NOTE(review): `counts` is the dummy "(c)" operand, written in place as
    # scratch space for the per-cohort sample tallies.
    out[:] = 0
    counts[:] = 0
    for idx in range(len(values)):
        cohort = cohorts[idx]
        if cohort >= 0:
            out[cohort] += values[idx]
            counts[cohort] += 1
    for cohort in range(len(out)):
        if counts[cohort] == 0:
            out[cohort] = np.nan
        else:
            out[cohort] /= counts[cohort]
@cohort_reduction
@guvectorize(  # type: ignore
    [
        "void(float32[:], int64[:], uint64[:], float64[:])",
        "void(float64[:], int64[:], uint64[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
    nopython=True,
    cache=False,
)
def cohort_nanmean(
    values: ArrayLike, cohorts: ArrayLike, counts: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    # Mean of each cohort's non-NaN sample values; a cohort with no
    # non-NaN samples yields NaN.
    # NOTE(review): `counts` is the dummy "(c)" operand, written in place as
    # scratch space for the per-cohort non-NaN tallies.
    out[:] = 0
    counts[:] = 0
    for idx in range(len(values)):
        val = values[idx]
        if np.isnan(val):
            continue
        cohort = cohorts[idx]
        if cohort >= 0:
            out[cohort] += val
            counts[cohort] += 1
    for cohort in range(len(out)):
        if counts[cohort] == 0:
            out[cohort] = np.nan
        else:
            out[cohort] /= counts[cohort]
# reference implementation
def cohort_statistic(
    values: ArrayLike,
    statistic: Callable[..., ArrayLike],
    cohorts: ArrayLike,
    sample_axis: int = 1,
    **kwargs: Any,
) -> da.Array:
    """Reference implementation: reduce the samples of each cohort by
    selecting each cohort's samples with a boolean mask, applying
    `statistic` (extra kwargs forwarded), and stacking the per-cohort
    results back along the samples axis.
    """
    values = da.asarray(values)
    cohorts = np.array(cohorts)
    n_cohorts = cohorts.max() + 1
    masks = (cohorts == c for c in range(n_cohorts))
    reduced = [
        statistic(da.take(values, mask, axis=sample_axis), **kwargs)
        for mask in masks
    ]
    return da.stack(reduced, axis=sample_axis)
import numpy as np
import dask.array as da
import sgkit as sg
import pytest
from cohorts import cohort_statistic, cohort_sum, cohort_nansum, cohort_mean, cohort_nanmean
@pytest.mark.parametrize("cohort_func, generic_func,", [
    (cohort_sum, np.sum),
    (cohort_nansum, np.nansum),
    (cohort_mean, np.mean),
    (cohort_nanmean, np.nanmean),
])
@pytest.mark.parametrize("seed", [1, 2])
@pytest.mark.parametrize("n_sample", [7, 50])
@pytest.mark.parametrize("n_cohort", [1, 2, 7])
def test_1d(cohort_func, generic_func, n_sample, n_cohort, seed):
    # 1-d case: samples on axis 0. The fused kernel must match the
    # take/stack reference implementation.
    np.random.seed(seed)
    cohort_of_sample = np.random.randint(0, n_cohort, size=n_sample)
    data = np.random.rand(n_sample)
    actual = cohort_func(data, cohort_of_sample, sample_axis=0)
    expected = cohort_statistic(data, generic_func, cohort_of_sample, axis=0, sample_axis=0)
    np.testing.assert_array_almost_equal(actual, expected)
@pytest.mark.parametrize("cohort_func, generic_func,", [
    (cohort_sum, np.sum),
    (cohort_nansum, np.nansum),
    (cohort_mean, np.mean),
    (cohort_nanmean, np.nanmean),
])
@pytest.mark.parametrize("chunks", [((10,), (7,)), ((1000,), (5,5)), ((100,) *10, (20,20,20))])
@pytest.mark.parametrize("seed", [1, 2])
@pytest.mark.parametrize("n_allele", [2, 5])
@pytest.mark.parametrize("n_ploidy", [2, 4])
@pytest.mark.parametrize("n_cohorts", [1, 2, 3])
def test_2d(cohort_func, generic_func, chunks, seed, n_allele, n_ploidy, n_cohorts):
    # 2-d case: (variants, samples) heterozygosity, with chunking along
    # both axes exercised via the `chunks` parameter.
    n_variant = sum(chunks[0])
    n_sample = sum(chunks[1])
    ds = sg.simulate_genotype_call_dataset(
        n_variant=n_variant,
        n_sample=n_sample,
        n_allele=n_allele,
        n_ploidy=n_ploidy,
        seed=seed,
        missing_pct=0.1,
    )
    np.random.seed(seed)
    ds["sample_cohort"] = "samples", np.random.randint(0, n_cohorts, size=n_sample)
    ds = sg.individual_heterozygosity(ds)
    ds["call_heterozygosity"] = ds.call_heterozygosity.chunk(chunks)
    actual = cohort_func(ds.call_heterozygosity, ds.sample_cohort)
    expected = cohort_statistic(ds.call_heterozygosity, generic_func, ds.sample_cohort, axis=1)
    np.testing.assert_array_almost_equal(actual, expected)
@pytest.mark.parametrize("cohort_func, generic_func,", [
    (cohort_sum, np.sum),
    (cohort_nansum, np.nansum),
    (cohort_mean, np.mean),
    (cohort_nanmean, np.nanmean),
])
@pytest.mark.parametrize("chunks", [((10,), (7,), (2,))])
@pytest.mark.parametrize("seed", [1, 2])
@pytest.mark.parametrize("n_ploidy", [2, 4])
@pytest.mark.parametrize("n_cohorts", [1, 2, 3])
def test_3d(cohort_func, generic_func, chunks, seed, n_ploidy, n_cohorts):
    # 3-d case: (variants, samples, alleles) allele frequencies, reducing
    # the samples axis (axis 1) per cohort.
    n_variant = sum(chunks[0])
    n_sample = sum(chunks[1])
    n_allele = sum(chunks[2])
    ds = sg.simulate_genotype_call_dataset(
        n_variant=n_variant,
        n_sample=n_sample,
        n_allele=n_allele,
        n_ploidy=n_ploidy,
        seed=seed,
        missing_pct=0.1,
    )
    np.random.seed(seed)
    ds["sample_cohort"] = "samples", np.random.randint(0, n_cohorts, size=n_sample)
    ds = sg.call_allele_frequencies(ds)
    ds["call_allele_frequency"] = ds.call_allele_frequency.chunk(chunks)
    actual = cohort_func(ds.call_allele_frequency, ds.sample_cohort)
    expected = cohort_statistic(ds.call_allele_frequency, generic_func, ds.sample_cohort, axis=1)
    np.testing.assert_array_almost_equal(actual, expected)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment