Created
December 10, 2021 04:01
-
-
Save timothymillar/765adee5a4f6c14e7d218db4500caeb1 to your computer and use it in GitHub Desktop.
Test cohort reductions for sgkit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import dask.array as da | |
from numba import guvectorize | |
from sgkit.typing import ArrayLike | |
import sgkit as sg | |
from typing import Any, Callable, Hashable, Iterable, Optional, Tuple, Union | |
def cohort_reduction(statistic: Callable) -> Callable:
    """Decorator turning a per-block gufunc into a dask-aware cohort reduction.

    Parameters
    ----------
    statistic
        A (numba-guvectorized) callable with gufunc signature
        ``(n),(n),(c)->(c)`` that reduces sample values into per-cohort values.

    Returns
    -------
    A function applying ``statistic`` block-wise over a dask array.
    """

    def cohort_statistic(values: ArrayLike, cohorts: ArrayLike, sample_axis: int = 1) -> ArrayLike:
        """Apply a reduction operation to the samples of each cohort.

        Parameters
        ----------
        values
            N-dimensional array with a samples dimension.
        cohorts
            Array of integers indicating the cohort each sample is assigned to.
            Negative values mean "unassigned" and are skipped by the statistics.
        sample_axis
            Integer indicating the samples dimension in the values array (default = 1).

        Returns
        -------
        Array of results for each cohort. This is the same shape as the values array
        with the exception that the samples axis now has length equal to the number of
        cohorts.
        """
        # asarray avoids copying data that is already a dask array
        # (the previous da.array call always copied).
        values = da.asarray(values)
        cohorts = np.array(cohorts)
        assert cohorts.ndim == 1
        # Cohorts are labelled 0..max; negative labels are ignored samples.
        n_cohorts = cohorts.max() + 1
        shape = list(values.shape)
        chunks = list(values.chunks)
        shape[sample_axis] = n_cohorts
        # The cohort axis of the result is always a single chunk.
        chunks[sample_axis] = (n_cohorts,)
        shape = tuple(shape)
        chunks = tuple(chunks)
        # Dummy array whose only purpose is to give the gufunc a third input
        # with core dimension (c); mean statistics also use it as scratch
        # space for per-cohort sample counts.
        dummy = da.zeros(shape, chunks=chunks, dtype=np.uint64)
        # NOTE(review): integer signatures of the gufuncs return int64 blocks
        # while the declared meta dtype here is float64 — confirm intended.
        out = da.map_blocks(
            statistic,
            values,
            cohorts,
            dummy,
            chunks=chunks,
            drop_axis=sample_axis,
            new_axis=sample_axis,
            dtype=np.float64,
            # `axes` is forwarded to the gufunc: core dims (n),(n),(c)->(c)
            # map to sample_axis of values, axis 0 of cohorts, and
            # sample_axis of the dummy and output arrays.
            axes=[sample_axis, 0, sample_axis, sample_axis],
        )
        return out

    # Preserve the wrapped statistic's name for introspection and tests.
    cohort_statistic.__name__ = statistic.__name__
    return cohort_statistic
@cohort_reduction
@guvectorize(  # type: ignore
    [
        "void(int32[:], int64[:], uint64[:], int64[:])",
        "void(int64[:], int64[:], uint64[:], int64[:])",
        "void(float32[:], int64[:], uint64[:], float64[:])",
        "void(float64[:], int64[:], uint64[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
    nopython=True,
    cache=False,
)
def cohort_sum(
    values: ArrayLike, cohorts: ArrayLike, _: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    """Sum the sample values belonging to each cohort.

    Samples with a negative cohort label are ignored.
    """
    out[:] = 0
    # Walk the values and their cohort labels in lockstep.
    for value, cohort in zip(values, cohorts):
        if cohort >= 0:
            out[cohort] += value
@cohort_reduction
@guvectorize(  # type: ignore
    [
        "void(int32[:], int64[:], uint64[:], int64[:])",
        "void(int64[:], int64[:], uint64[:], int64[:])",
        "void(float32[:], int64[:], uint64[:], float64[:])",
        "void(float64[:], int64[:], uint64[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
    nopython=True,
    cache=False,
)
def cohort_nansum(
    values: ArrayLike, cohorts: ArrayLike, _: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    """Sum the sample values belonging to each cohort, ignoring NaNs.

    NaN values are treated as missing data; samples with a negative
    cohort label are also ignored.
    """
    out[:] = 0
    for value, cohort in zip(values, cohorts):
        # Skip missing observations before checking cohort membership.
        if (not np.isnan(value)) and cohort >= 0:
            out[cohort] += value
@cohort_reduction
@guvectorize(  # type: ignore
    [
        "void(float32[:], int64[:], uint64[:], float64[:])",
        "void(float64[:], int64[:], uint64[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
    nopython=True,
    cache=False,
)
def cohort_mean(
    values: ArrayLike, cohorts: ArrayLike, counts: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    """Mean of the sample values within each cohort.

    Samples with a negative cohort label are ignored. A cohort with no
    samples yields NaN. The ``counts`` buffer is used as scratch space
    for per-cohort sample counts.
    """
    out[:] = 0
    counts[:] = 0
    # First pass: accumulate per-cohort totals and sample counts.
    for value, cohort in zip(values, cohorts):
        if cohort >= 0:
            out[cohort] += value
            counts[cohort] += 1
    # Second pass: convert totals into means; empty cohorts become NaN.
    for j in range(len(out)):
        if counts[j] != 0:
            out[j] /= counts[j]
        else:
            out[j] = np.nan
@cohort_reduction
@guvectorize(  # type: ignore
    [
        "void(float32[:], int64[:], uint64[:], float64[:])",
        "void(float64[:], int64[:], uint64[:], float64[:])",
    ],
    "(n),(n),(c)->(c)",
    nopython=True,
    cache=False,
)
def cohort_nanmean(
    values: ArrayLike, cohorts: ArrayLike, counts: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    """Mean of the sample values within each cohort, ignoring NaNs.

    NaN values are treated as missing data and excluded from both the
    totals and the counts; samples with a negative cohort label are also
    ignored. A cohort with no (non-NaN) samples yields NaN. The ``counts``
    buffer is used as scratch space for per-cohort sample counts.
    """
    out[:] = 0
    counts[:] = 0
    # First pass: accumulate totals and counts over non-missing samples.
    for value, cohort in zip(values, cohorts):
        if not np.isnan(value):
            if cohort >= 0:
                out[cohort] += value
                counts[cohort] += 1
    # Second pass: convert totals into means; empty cohorts become NaN.
    for j in range(len(out)):
        if counts[j] != 0:
            out[j] /= counts[j]
        else:
            out[j] = np.nan
# Reference implementation used to validate the optimized reductions above.
def cohort_statistic(
    values: ArrayLike,
    statistic: Callable[..., ArrayLike],
    cohorts: ArrayLike,
    sample_axis: int = 1,
    **kwargs: Any,
) -> da.Array:
    """Apply ``statistic`` to the samples of each cohort via explicit selection.

    For each cohort in turn, the samples belonging to that cohort are taken
    along ``sample_axis``, ``statistic`` is applied to that subset (with any
    extra keyword arguments forwarded), and the per-cohort results are
    stacked along the sample axis.
    """
    array = da.asarray(values)
    labels = np.array(cohorts)
    n_cohorts = labels.max() + 1
    # One reduction per cohort: mask out that cohort's samples and reduce.
    results = [
        statistic(da.take(array, labels == cohort, axis=sample_axis), **kwargs)
        for cohort in range(n_cohorts)
    ]
    return da.stack(results, axis=sample_axis)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import dask.array as da | |
import sgkit as sg | |
import pytest | |
from cohorts import cohort_statistic, cohort_sum, cohort_nansum, cohort_mean, cohort_nanmean | |
@pytest.mark.parametrize("cohort_func, generic_func,", [
    (cohort_sum, np.sum),
    (cohort_nansum, np.nansum),
    (cohort_mean, np.mean),
    (cohort_nanmean, np.nanmean),
])
@pytest.mark.parametrize("seed", [1, 2])
@pytest.mark.parametrize("n_sample", [7, 50])
@pytest.mark.parametrize("n_cohort", [1, 2, 7])
def test_1d(cohort_func, generic_func, n_sample, n_cohort, seed):
    """Optimized reduction matches the reference implementation on 1-D data."""
    np.random.seed(seed)
    cohorts = np.random.randint(0, n_cohort, size=n_sample)
    values = np.random.rand(n_sample)
    # Sample axis is 0 for 1-D input; axis=0 is forwarded to the numpy reducer.
    actual = cohort_func(values, cohorts, sample_axis=0)
    expect = cohort_statistic(values, generic_func, cohorts, axis=0, sample_axis=0)
    np.testing.assert_array_almost_equal(actual, expect)
@pytest.mark.parametrize("cohort_func, generic_func,", [
    (cohort_sum, np.sum),
    (cohort_nansum, np.nansum),
    (cohort_mean, np.mean),
    (cohort_nanmean, np.nanmean),
])
@pytest.mark.parametrize("chunks", [((10,), (7,)), ((1000,), (5,5)), ((100,) *10, (20,20,20))])
@pytest.mark.parametrize("seed", [1, 2])
@pytest.mark.parametrize("n_allele", [2, 5])
@pytest.mark.parametrize("n_ploidy", [2, 4])
@pytest.mark.parametrize("n_cohorts", [1, 2, 3])
def test_2d(cohort_func, generic_func, chunks, seed, n_allele, n_ploidy, n_cohorts):
    """Optimized reduction matches the reference on chunked 2-D data."""
    # Dimension sizes are implied by the requested chunking.
    n_variant = sum(chunks[0])
    n_sample = sum(chunks[1])
    ds = sg.simulate_genotype_call_dataset(
        n_variant=n_variant,
        n_sample=n_sample,
        n_allele=n_allele,
        n_ploidy=n_ploidy,
        seed=seed,
        missing_pct=0.1,
    )
    np.random.seed(seed)
    ds["sample_cohort"] = "samples", np.random.randint(0, n_cohorts, size=n_sample)
    # Per-sample heterozygosity yields a (variants, samples) float array.
    ds = sg.individual_heterozygosity(ds)
    ds["call_heterozygosity"] = ds.call_heterozygosity.chunk(chunks)
    actual = cohort_func(ds.call_heterozygosity, ds.sample_cohort)
    expect = cohort_statistic(ds.call_heterozygosity, generic_func, ds.sample_cohort, axis=1)
    np.testing.assert_array_almost_equal(actual, expect)
@pytest.mark.parametrize("cohort_func, generic_func,", [
    (cohort_sum, np.sum),
    (cohort_nansum, np.nansum),
    (cohort_mean, np.mean),
    (cohort_nanmean, np.nanmean),
])
@pytest.mark.parametrize("chunks", [((10,), (7,), (2,))])
@pytest.mark.parametrize("seed", [1, 2])
@pytest.mark.parametrize("n_ploidy", [2, 4])
@pytest.mark.parametrize("n_cohorts", [1, 2, 3])
def test_3d(cohort_func, generic_func, chunks, seed, n_ploidy, n_cohorts):
    """Optimized reduction matches the reference on chunked 3-D data."""
    # Dimension sizes are implied by the requested chunking.
    n_variant = sum(chunks[0])
    n_sample = sum(chunks[1])
    n_allele = sum(chunks[2])
    ds = sg.simulate_genotype_call_dataset(
        n_variant=n_variant,
        n_sample=n_sample,
        n_allele=n_allele,
        n_ploidy=n_ploidy,
        seed=seed,
        missing_pct=0.1,
    )
    np.random.seed(seed)
    ds["sample_cohort"] = "samples", np.random.randint(0, n_cohorts, size=n_sample)
    # Allele frequencies yield a (variants, samples, alleles) float array;
    # the default sample_axis=1 is exercised here.
    ds = sg.call_allele_frequencies(ds)
    ds["call_allele_frequency"] = ds.call_allele_frequency.chunk(chunks)
    actual = cohort_func(ds.call_allele_frequency, ds.sample_cohort)
    expect = cohort_statistic(ds.call_allele_frequency, generic_func, ds.sample_cohort, axis=1)
    np.testing.assert_array_almost_equal(actual, expect)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment