Last active: January 13, 2021 05:27
-
-
Save czgdp1807/879223647496d9cef0b246272731c3db to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/sympy/stats/sampling/tests/test_sample_continuous_rv.py b/sympy/stats/sampling/tests/test_sample_continuous_rv.py | |
index 7f13b4061e..46993145cb 100644 | |
--- a/sympy/stats/sampling/tests/test_sample_continuous_rv.py | |
+++ b/sympy/stats/sampling/tests/test_sample_continuous_rv.py | |
@@ -1,6 +1,7 @@ | |
+from sympy import exp, Interval, oo, Symbol | |
from sympy.external import import_module | |
from sympy.stats import Beta, Chi, Normal, Gamma, Exponential, LogNormal, Pareto, ChiSquared, Uniform, sample, \ | |
- BetaPrime, Cauchy, GammaInverse, GaussianInverse, StudentT | |
+ BetaPrime, Cauchy, GammaInverse, GaussianInverse, StudentT, Weibull, density, ContinuousRV | |
from sympy.testing.pytest import skip, ignore_warnings, raises | |
@@ -92,3 +93,91 @@ def test_sample_pymc3(): | |
assert sam in X.pspace.domain.set | |
raises(NotImplementedError, | |
lambda: next(sample(Chi("C", 1), library='pymc3'))) | |
+
+
+def test_sampling_gamma_inverse():  # a GammaInverse draw must lie in the RV's support
+    scipy = import_module('scipy')
+    if not scipy:
+        skip('Scipy not installed. Abort tests for sampling of gamma inverse.')
+    X = GammaInverse("x", 1, 1)
+    with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+        assert next(sample(X)) in X.pspace.domain.set
+
+
+def test_lognormal_sampling():
+    # Single draws for mean parameter i = 0, 1, 2, then a batch of `size` draws from the last X.
+    scipy = import_module('scipy')
+    if not scipy:
+        skip('Scipy is not installed. Abort tests')
+    with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+        for i in range(3):
+            X = LogNormal('x', i, 1)
+            assert next(sample(X)) in X.pspace.domain.set
+
+    size = 5
+    with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+        samps = next(sample(X, size=size))
+        for samp in samps:
+            assert samp in X.pspace.domain.set
+
+
+def test_sampling_gaussian_inverse():  # an explicit library='scipy' draw must lie in the support
+    scipy = import_module('scipy')
+    if not scipy:
+        skip('Scipy not installed. Abort tests for sampling of Gaussian inverse.')
+    X = GaussianInverse("x", 1, 1)
+    with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+        assert next(sample(X, library='scipy')) in X.pspace.domain.set
+
+
+def test_prefab_sampling():  # repeated single and batched draws must stay in each RV's support
+    scipy = import_module('scipy')
+    if not scipy:
+        skip('Scipy is not installed. Abort tests')
+    N = Normal('X', 0, 1)
+    L = LogNormal('L', 0, 1)
+    E = Exponential('Ex', 1)
+    P = Pareto('P', 1, 3)
+    W = Weibull('W', 1, 1)
+    U = Uniform('U', 0, 1)
+    B = Beta('B', 2, 5)
+    G = Gamma('G', 1, 3)
+
+    variables = [N, L, E, P, W, U, B, G]
+    niter = 10
+    size = 5
+    with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+        for var in variables:
+            for _ in range(niter):
+                assert next(sample(var)) in var.pspace.domain.set
+                samps = next(sample(var, size=size))
+                for samp in samps:
+                    assert samp in var.pspace.domain.set
+
+
+def test_sample_continuous():  # support checks plus seeded reproducibility per backend
+    z = Symbol('z')
+    Z = ContinuousRV(z, exp(-z), set=Interval(0, oo))
+    assert density(Z)(-1) == 0  # -1 lies outside the declared support [0, oo)
+
+    scipy = import_module('scipy')
+    if not scipy:
+        skip('Scipy is not installed. Abort tests')
+    with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+        assert next(sample(Z)) in Z.pspace.domain.set
+        sym, val = list(Z.pspace.sample().items())[0]
+        assert sym == Z and val in Interval(0, oo)
+
+    libraries = ['scipy', 'numpy', 'pymc3']
+    # Seeding must be reproducible for every installed backend that can sample Z.
+    for lib in libraries:
+        if not import_module(lib):
+            continue  # optional backend not installed
+        try:
+            s0 = list(sample(Z, numsamples=10, library=lib, seed=0))
+            s1 = list(sample(Z, numsamples=10, library=lib, seed=0))
+            s2 = list(sample(Z, numsamples=10, library=lib, seed=1))
+        except NotImplementedError:
+            continue  # this library cannot sample this RV
+        assert s0 == s1  # same seed reproduces the same draws
+        assert s1 != s2  # a different seed changes them
diff --git a/sympy/stats/sampling/tests/test_sample_discrete_rv.py b/sympy/stats/sampling/tests/test_sample_discrete_rv.py | |
index ed6c64af78..f20b1633e2 100644 | |
--- a/sympy/stats/sampling/tests/test_sample_discrete_rv.py | |
+++ b/sympy/stats/sampling/tests/test_sample_discrete_rv.py | |
@@ -1,7 +1,7 @@ | |
from sympy import S, Symbol | |
from sympy.external import import_module | |
from sympy.stats import Geometric, Poisson, Zeta, sample, Skellam, DiscreteRV, Logarithmic, NegativeBinomial, YuleSimon | |
-from sympy.testing.pytest import skip, ignore_warnings, raises | |
+from sympy.testing.pytest import skip, ignore_warnings, raises, slow | |
def test_sample_numpy(): | |
@@ -77,3 +77,29 @@ def test_sample_pymc3(): | |
assert sam in X.pspace.domain.set | |
raises(NotImplementedError, | |
lambda: next(sample(Skellam('S', 1, 1), library='pymc3'))) | |
+
+
+@slow
+def test_sample_discrete():  # support checks plus seeded reproducibility per backend
+    X = Geometric('X', S.Half)
+    scipy = import_module('scipy')
+    if not scipy:
+        skip('Scipy not installed. Abort tests')
+    with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+        assert next(sample(X)) in X.pspace.domain.set
+        samps = next(sample(X, size=2)) # This takes long time if ran without scipy
+        for samp in samps:
+            assert samp in X.pspace.domain.set
+
+    libraries = ['scipy', 'numpy', 'pymc3']
+    for lib in libraries:
+        if not import_module(lib):
+            continue  # optional backend not installed
+        try:
+            s0 = list(sample(X, numsamples=10, library=lib, seed=0))
+            s1 = list(sample(X, numsamples=10, library=lib, seed=0))
+            s2 = list(sample(X, numsamples=10, library=lib, seed=1))
+        except NotImplementedError:
+            continue  # this library cannot sample this RV
+        assert s0 == s1  # same seed reproduces the same draws
+        assert s1 != s2  # a different seed changes them
diff --git a/sympy/stats/sampling/tests/test_sample_finite_rv.py b/sympy/stats/sampling/tests/test_sample_finite_rv.py | |
index 41019cbbcb..c2e6aba6fe 100644 | |
--- a/sympy/stats/sampling/tests/test_sample_finite_rv.py | |
+++ b/sympy/stats/sampling/tests/test_sample_finite_rv.py | |
@@ -4,6 +4,13 @@ | |
Rademacher | |
from sympy.testing.pytest import skip, ignore_warnings, raises | |
+def test_given_sample():  # Die('X', 6) conditioned on X > 5 can only yield 6
+    X = Die('X', 6)
+    scipy = import_module('scipy')
+    if not scipy:
+        skip('Scipy is not installed. Abort tests')
+    with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+        assert next(sample(X, X > 5)) == 6
def test_sample_numpy(): | |
distribs_numpy = [ | |
diff --git a/sympy/stats/tests/test_continuous_rv.py b/sympy/stats/tests/test_continuous_rv.py | |
index f9d79d1340..e26edb766d 100644 | |
--- a/sympy/stats/tests/test_continuous_rv.py | |
+++ b/sympy/stats/tests/test_continuous_rv.py | |
@@ -11,7 +11,7 @@ | |
from sympy.sets.sets import Intersection, FiniteSet | |
from sympy.stats import (P, E, where, density, variance, covariance, skewness, kurtosis, median, | |
given, pspace, cdf, characteristic_function, moment_generating_function, | |
- ContinuousRV, sample, Arcsin, Benini, Beta, BetaNoncentral, BetaPrime, | |
+ ContinuousRV, Arcsin, Benini, Beta, BetaNoncentral, BetaPrime, | |
Cauchy, Chi, ChiSquared, ChiNoncentral, Dagum, Erlang, ExGaussian, | |
Exponential, ExponentialPower, FDistribution, FisherZ, Frechet, Gamma, | |
GammaInverse, Gompertz, Gumbel, Kumaraswamy, Laplace, Levy, Logistic, LogCauchy, | |
@@ -321,33 +321,6 @@ def test_moment_generating_function(): | |
besseli(0, 1) | |
-def test_sample_continuous(): | |
- Z = ContinuousRV(z, exp(-z), set=Interval(0, oo)) | |
- assert density(Z)(-1) == 0 | |
- | |
- scipy = import_module('scipy') | |
- if not scipy: | |
- skip('Scipy is not installed. Abort tests') | |
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | |
- assert next(sample(Z)) in Z.pspace.domain.set | |
- sym, val = list(Z.pspace.sample().items())[0] | |
- assert sym == Z and val in Interval(0, oo) | |
- | |
- libraries = ['scipy', 'numpy', 'pymc3'] | |
- for lib in libraries: | |
- try: | |
- imported_lib = import_module(lib) | |
- if imported_lib: | |
- s0, s1, s2 = [], [], [] | |
- s0 = list(sample(Z, numsamples=10, library=lib, seed=0)) | |
- s1 = list(sample(Z, numsamples=10, library=lib, seed=0)) | |
- s2 = list(sample(Z, numsamples=10, library=lib, seed=1)) | |
- assert s0 == s1 | |
- assert s1 != s2 | |
- except NotImplementedError: | |
- continue | |
- | |
- | |
def test_ContinuousRV(): | |
pdf = sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)) # Normal distribution | |
# X and Y should be equivalent | |
@@ -777,14 +750,6 @@ def test_gamma_inverse(): | |
* besselk(a, 2*sqrt(b)*sqrt(-I*x))/gamma(a) | |
raises(NotImplementedError, lambda: moment_generating_function(X)) | |
-def test_sampling_gamma_inverse(): | |
- scipy = import_module('scipy') | |
- if not scipy: | |
- skip('Scipy not installed. Abort tests for sampling of gamma inverse.') | |
- X = GammaInverse("x", 1, 1) | |
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | |
- assert next(sample(X)) in X.pspace.domain.set | |
- | |
def test_gompertz(): | |
b = Symbol("b", positive=True) | |
eta = Symbol("eta", positive=True) | |
@@ -924,20 +889,6 @@ def test_lognormal(): | |
#assert E(X) == exp(mean+std**2/2) | |
#assert variance(X) == (exp(std**2)-1) * exp(2*mean + std**2) | |
- # Right now, only density function and sampling works | |
- scipy = import_module('scipy') | |
- if not scipy: | |
- skip('Scipy is not installed. Abort tests') | |
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | |
- for i in range(3): | |
- X = LogNormal('x', i, 1) | |
- assert next(sample(X)) in X.pspace.domain.set | |
- | |
- size = 5 | |
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | |
- samps = next(sample(X, size=size)) | |
- for samp in samps: | |
- assert samp in X.pspace.domain.set | |
# The sympy integrator can't do this too well | |
#assert E(X) == | |
raises(NotImplementedError, lambda: moment_generating_function(X)) | |
@@ -1052,14 +1003,6 @@ def test_gaussian_inverse(): | |
b = symbols('b', nonpositive=True) | |
raises(ValueError, lambda: GaussianInverse('x', a, b)) | |
-def test_sampling_gaussian_inverse(): | |
- scipy = import_module('scipy') | |
- if not scipy: | |
- skip('Scipy not installed. Abort tests for sampling of Gaussian inverse.') | |
- X = GaussianInverse("x", 1, 1) | |
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | |
- assert next(sample(X, library='scipy')) in X.pspace.domain.set | |
- | |
def test_pareto(): | |
xm, beta = symbols('xm beta', positive=True) | |
alpha = beta + 5 | |
@@ -1343,30 +1286,6 @@ def test_wignersemicircle(): | |
Piecewise((2*besselj(1, R*x)/(R*x), Ne(x, 0)), (1, True)) | |
-def test_prefab_sampling(): | |
- scipy = import_module('scipy') | |
- if not scipy: | |
- skip('Scipy is not installed. Abort tests') | |
- N = Normal('X', 0, 1) | |
- L = LogNormal('L', 0, 1) | |
- E = Exponential('Ex', 1) | |
- P = Pareto('P', 1, 3) | |
- W = Weibull('W', 1, 1) | |
- U = Uniform('U', 0, 1) | |
- B = Beta('B', 2, 5) | |
- G = Gamma('G', 1, 3) | |
- | |
- variables = [N, L, E, P, W, U, B, G] | |
- niter = 10 | |
- size = 5 | |
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | |
- for var in variables: | |
- for _ in range(niter): | |
- assert next(sample(var)) in var.pspace.domain.set | |
- samps = next(sample(var, size=size)) | |
- for samp in samps: | |
- assert samp in var.pspace.domain.set | |
- | |
def test_input_value_assertions(): | |
a, b = symbols('a b') | |
p, q = symbols('p q', positive=True) | |
diff --git a/sympy/stats/tests/test_discrete_rv.py b/sympy/stats/tests/test_discrete_rv.py | |
index 56304da869..e4a693ec20 100644 | |
--- a/sympy/stats/tests/test_discrete_rv.py | |
+++ b/sympy/stats/tests/test_discrete_rv.py | |
@@ -12,7 +12,6 @@ | |
FlorySchulz, Poisson, Geometric, Hermite, Logarithmic, | |
NegativeBinomial, Skellam, YuleSimon, Zeta, | |
DiscreteRV) | |
-from sympy.stats.rv import sample | |
from sympy.testing.pytest import slow, nocache_fail, raises, skip, ignore_warnings | |
from sympy.external import import_module | |
from sympy.stats.symbolic_probability import Expectation | |
@@ -152,32 +151,6 @@ def test_zeta(): | |
zeta(s) * zeta(s-2) - zeta(s-1)**2) / zeta(s)**2 | |
-@slow | |
-def test_sample_discrete(): | |
- X = Geometric('X', S.Half) | |
- scipy = import_module('scipy') | |
- if not scipy: | |
- skip('Scipy not installed. Abort tests') | |
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | |
- assert next(sample(X)) in X.pspace.domain.set | |
- samps = next(sample(X, size=2)) # This takes long time if ran without scipy | |
- for samp in samps: | |
- assert samp in X.pspace.domain.set | |
- | |
- libraries = ['scipy', 'numpy', 'pymc3'] | |
- for lib in libraries: | |
- try: | |
- imported_lib = import_module(lib) | |
- if imported_lib: | |
- s0, s1, s2 = [], [], [] | |
- s0 = list(sample(X, numsamples=10, library=lib, seed=0)) | |
- s1 = list(sample(X, numsamples=10, library=lib, seed=0)) | |
- s2 = list(sample(X, numsamples=10, library=lib, seed=1)) | |
- assert s0 == s1 | |
- assert s1 != s2 | |
- except NotImplementedError: | |
- continue | |
- | |
def test_discrete_probability(): | |
X = Geometric('X', Rational(1, 5)) | |
Y = Poisson('Y', 4) | |
diff --git a/sympy/stats/tests/test_finite_rv.py b/sympy/stats/tests/test_finite_rv.py | |
index 08cdc38845..644eedde14 100644 | |
--- a/sympy/stats/tests/test_finite_rv.py | |
+++ b/sympy/stats/tests/test_finite_rv.py | |
@@ -5,7 +5,7 @@ | |
from sympy.matrices import Matrix | |
from sympy.stats import (DiscreteUniform, Die, Bernoulli, Coin, Binomial, BetaBinomial, | |
Hypergeometric, Rademacher, IdealSoliton, RobustSoliton, P, E, variance, | |
- covariance, skewness, sample, density, where, FiniteRV, pspace, cdf, | |
+ covariance, skewness, density, where, FiniteRV, pspace, cdf, | |
correlation, moment, cmoment, smoment, characteristic_function, | |
moment_generating_function, quantile, kurtosis, median, coskewness) | |
from sympy.stats.frv_types import DieDistribution, BinomialDistribution, \ | |
@@ -146,11 +146,6 @@ def test_given(): | |
X = Die('X', 6) | |
assert density(X, X > 5) == {S(6): S.One} | |
assert where(X > 2, X > 5).as_boolean() == Eq(X.symbol, 6) | |
- scipy = import_module('scipy') | |
- if not scipy: | |
- skip('Scipy is not installed. Abort tests') | |
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed | |
- assert next(sample(X, X > 5)) == 6 | |
def test_domains(): |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.