Skip to content

Instantly share code, notes, and snippets.

@czgdp1807
Last active January 13, 2021 05:27
Show Gist options
  • Save czgdp1807/879223647496d9cef0b246272731c3db to your computer and use it in GitHub Desktop.
Save czgdp1807/879223647496d9cef0b246272731c3db to your computer and use it in GitHub Desktop.
diff --git a/sympy/stats/sampling/tests/test_sample_continuous_rv.py b/sympy/stats/sampling/tests/test_sample_continuous_rv.py
index 7f13b4061e..46993145cb 100644
--- a/sympy/stats/sampling/tests/test_sample_continuous_rv.py
+++ b/sympy/stats/sampling/tests/test_sample_continuous_rv.py
@@ -1,6 +1,7 @@
+from sympy import exp, Interval, oo, Symbol
from sympy.external import import_module
from sympy.stats import Beta, Chi, Normal, Gamma, Exponential, LogNormal, Pareto, ChiSquared, Uniform, sample, \
- BetaPrime, Cauchy, GammaInverse, GaussianInverse, StudentT
+ BetaPrime, Cauchy, GammaInverse, GaussianInverse, StudentT, Weibull, density, ContinuousRV
from sympy.testing.pytest import skip, ignore_warnings, raises
@@ -92,3 +93,92 @@ def test_sample_pymc3():
assert sam in X.pspace.domain.set
raises(NotImplementedError,
lambda: next(sample(Chi("C", 1), library='pymc3')))
+
+
+def test_sampling_gamma_inverse():
+ scipy = import_module('scipy')
+ if not scipy:
+ skip('Scipy not installed. Abort tests for sampling of gamma inverse.')
+ X = GammaInverse("x", 1, 1)
+ with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+ assert next(sample(X)) in X.pspace.domain.set
+
+
+def test_lognormal_sampling():
+ # Only sampling is exercised here; the symbolic LogNormal checks remain in test_lognormal.
+ scipy = import_module('scipy')
+ if not scipy:
+ skip('Scipy is not installed. Abort tests')
+ with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+ for i in range(3):
+ X = LogNormal('x', i, 1)
+ assert next(sample(X)) in X.pspace.domain.set
+
+ size = 5
+ with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+ samps = next(sample(X, size=size))
+ for samp in samps:
+ assert samp in X.pspace.domain.set
+
+
+def test_sampling_gaussian_inverse():
+ scipy = import_module('scipy')
+ if not scipy:
+ skip('Scipy not installed. Abort tests for sampling of Gaussian inverse.')
+ X = GaussianInverse("x", 1, 1)
+ with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+ assert next(sample(X, library='scipy')) in X.pspace.domain.set
+
+
+def test_prefab_sampling():
+ scipy = import_module('scipy')
+ if not scipy:
+ skip('Scipy is not installed. Abort tests')
+ N = Normal('X', 0, 1)
+ L = LogNormal('L', 0, 1)
+ E = Exponential('Ex', 1)
+ P = Pareto('P', 1, 3)
+ W = Weibull('W', 1, 1)
+ U = Uniform('U', 0, 1)
+ B = Beta('B', 2, 5)
+ G = Gamma('G', 1, 3)
+
+ variables = [N, L, E, P, W, U, B, G]
+ niter = 10
+ size = 5
+ with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+ for var in variables:
+ for _ in range(niter):
+ assert next(sample(var)) in var.pspace.domain.set
+ samps = next(sample(var, size=size))
+ for samp in samps:
+ assert samp in var.pspace.domain.set
+
+
+def test_sample_continuous():
+ z = Symbol('z')
+ Z = ContinuousRV(z, exp(-z), set=Interval(0, oo))
+ assert density(Z)(-1) == 0
+
+ scipy = import_module('scipy')
+ if not scipy:
+ skip('Scipy is not installed. Abort tests')
+ with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+ assert next(sample(Z)) in Z.pspace.domain.set
+ sym, val = list(Z.pspace.sample().items())[0]
+ assert sym == Z and val in Interval(0, oo)
+
+ # A fixed seed must reproduce the same draws and a different seed should not;
+ # libraries that do not support this distribution raise NotImplementedError and are skipped.
+ libraries = ['scipy', 'numpy', 'pymc3']
+ for lib in libraries:
+ try:
+ imported_lib = import_module(lib)
+ if imported_lib:
+ s0 = list(sample(Z, numsamples=10, library=lib, seed=0))
+ s1 = list(sample(Z, numsamples=10, library=lib, seed=0))
+ s2 = list(sample(Z, numsamples=10, library=lib, seed=1))
+ assert s0 == s1
+ assert s1 != s2
+ except NotImplementedError:
+ continue
diff --git a/sympy/stats/sampling/tests/test_sample_discrete_rv.py b/sympy/stats/sampling/tests/test_sample_discrete_rv.py
index ed6c64af78..f20b1633e2 100644
--- a/sympy/stats/sampling/tests/test_sample_discrete_rv.py
+++ b/sympy/stats/sampling/tests/test_sample_discrete_rv.py
@@ -1,7 +1,7 @@
from sympy import S, Symbol
from sympy.external import import_module
from sympy.stats import Geometric, Poisson, Zeta, sample, Skellam, DiscreteRV, Logarithmic, NegativeBinomial, YuleSimon
-from sympy.testing.pytest import skip, ignore_warnings, raises
+from sympy.testing.pytest import skip, ignore_warnings, raises, slow
def test_sample_numpy():
@@ -77,3 +77,31 @@ def test_sample_pymc3():
assert sam in X.pspace.domain.set
raises(NotImplementedError,
lambda: next(sample(Skellam('S', 1, 1), library='pymc3')))
+
+
+@slow
+def test_sample_discrete():
+ X = Geometric('X', S.Half)
+ scipy = import_module('scipy')
+ if not scipy:
+ skip('Scipy not installed. Abort tests')
+ with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+ assert next(sample(X)) in X.pspace.domain.set
+ samps = next(sample(X, size=2)) # This takes a long time if run without scipy
+ for samp in samps:
+ assert samp in X.pspace.domain.set
+
+ # A fixed seed must reproduce the same draws and a different seed should not;
+ # libraries that do not support this distribution raise NotImplementedError and are skipped.
+ libraries = ['scipy', 'numpy', 'pymc3']
+ for lib in libraries:
+ try:
+ imported_lib = import_module(lib)
+ if imported_lib:
+ s0 = list(sample(X, numsamples=10, library=lib, seed=0))
+ s1 = list(sample(X, numsamples=10, library=lib, seed=0))
+ s2 = list(sample(X, numsamples=10, library=lib, seed=1))
+ assert s0 == s1
+ assert s1 != s2
+ except NotImplementedError:
+ continue
diff --git a/sympy/stats/sampling/tests/test_sample_finite_rv.py b/sympy/stats/sampling/tests/test_sample_finite_rv.py
index 41019cbbcb..c2e6aba6fe 100644
--- a/sympy/stats/sampling/tests/test_sample_finite_rv.py
+++ b/sympy/stats/sampling/tests/test_sample_finite_rv.py
@@ -4,6 +4,14 @@
Rademacher
from sympy.testing.pytest import skip, ignore_warnings, raises
+def test_given_sample():
+ # Conditioned on X > 5, a six-sided die can only produce the value 6.
+ X = Die('X', 6)
+ scipy = import_module('scipy')
+ if not scipy:
+ skip('Scipy is not installed. Abort tests')
+ with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
+ assert next(sample(X, X > 5)) == 6
def test_sample_numpy():
distribs_numpy = [
diff --git a/sympy/stats/tests/test_continuous_rv.py b/sympy/stats/tests/test_continuous_rv.py
index f9d79d1340..e26edb766d 100644
--- a/sympy/stats/tests/test_continuous_rv.py
+++ b/sympy/stats/tests/test_continuous_rv.py
@@ -11,7 +11,7 @@
from sympy.sets.sets import Intersection, FiniteSet
from sympy.stats import (P, E, where, density, variance, covariance, skewness, kurtosis, median,
given, pspace, cdf, characteristic_function, moment_generating_function,
- ContinuousRV, sample, Arcsin, Benini, Beta, BetaNoncentral, BetaPrime,
+ ContinuousRV, Arcsin, Benini, Beta, BetaNoncentral, BetaPrime,
Cauchy, Chi, ChiSquared, ChiNoncentral, Dagum, Erlang, ExGaussian,
Exponential, ExponentialPower, FDistribution, FisherZ, Frechet, Gamma,
GammaInverse, Gompertz, Gumbel, Kumaraswamy, Laplace, Levy, Logistic, LogCauchy,
@@ -321,33 +321,6 @@ def test_moment_generating_function():
besseli(0, 1)
-def test_sample_continuous():
- Z = ContinuousRV(z, exp(-z), set=Interval(0, oo))
- assert density(Z)(-1) == 0
-
- scipy = import_module('scipy')
- if not scipy:
- skip('Scipy is not installed. Abort tests')
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
- assert next(sample(Z)) in Z.pspace.domain.set
- sym, val = list(Z.pspace.sample().items())[0]
- assert sym == Z and val in Interval(0, oo)
-
- libraries = ['scipy', 'numpy', 'pymc3']
- for lib in libraries:
- try:
- imported_lib = import_module(lib)
- if imported_lib:
- s0, s1, s2 = [], [], []
- s0 = list(sample(Z, numsamples=10, library=lib, seed=0))
- s1 = list(sample(Z, numsamples=10, library=lib, seed=0))
- s2 = list(sample(Z, numsamples=10, library=lib, seed=1))
- assert s0 == s1
- assert s1 != s2
- except NotImplementedError:
- continue
-
-
def test_ContinuousRV():
pdf = sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)) # Normal distribution
# X and Y should be equivalent
@@ -777,14 +750,6 @@ def test_gamma_inverse():
* besselk(a, 2*sqrt(b)*sqrt(-I*x))/gamma(a)
raises(NotImplementedError, lambda: moment_generating_function(X))
-def test_sampling_gamma_inverse():
- scipy = import_module('scipy')
- if not scipy:
- skip('Scipy not installed. Abort tests for sampling of gamma inverse.')
- X = GammaInverse("x", 1, 1)
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
- assert next(sample(X)) in X.pspace.domain.set
-
def test_gompertz():
b = Symbol("b", positive=True)
eta = Symbol("eta", positive=True)
@@ -924,20 +889,6 @@ def test_lognormal():
#assert E(X) == exp(mean+std**2/2)
#assert variance(X) == (exp(std**2)-1) * exp(2*mean + std**2)
- # Right now, only density function and sampling works
- scipy = import_module('scipy')
- if not scipy:
- skip('Scipy is not installed. Abort tests')
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
- for i in range(3):
- X = LogNormal('x', i, 1)
- assert next(sample(X)) in X.pspace.domain.set
-
- size = 5
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
- samps = next(sample(X, size=size))
- for samp in samps:
- assert samp in X.pspace.domain.set
# The sympy integrator can't do this too well
#assert E(X) ==
raises(NotImplementedError, lambda: moment_generating_function(X))
@@ -1052,14 +1003,6 @@ def test_gaussian_inverse():
b = symbols('b', nonpositive=True)
raises(ValueError, lambda: GaussianInverse('x', a, b))
-def test_sampling_gaussian_inverse():
- scipy = import_module('scipy')
- if not scipy:
- skip('Scipy not installed. Abort tests for sampling of Gaussian inverse.')
- X = GaussianInverse("x", 1, 1)
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
- assert next(sample(X, library='scipy')) in X.pspace.domain.set
-
def test_pareto():
xm, beta = symbols('xm beta', positive=True)
alpha = beta + 5
@@ -1343,30 +1286,6 @@ def test_wignersemicircle():
Piecewise((2*besselj(1, R*x)/(R*x), Ne(x, 0)), (1, True))
-def test_prefab_sampling():
- scipy = import_module('scipy')
- if not scipy:
- skip('Scipy is not installed. Abort tests')
- N = Normal('X', 0, 1)
- L = LogNormal('L', 0, 1)
- E = Exponential('Ex', 1)
- P = Pareto('P', 1, 3)
- W = Weibull('W', 1, 1)
- U = Uniform('U', 0, 1)
- B = Beta('B', 2, 5)
- G = Gamma('G', 1, 3)
-
- variables = [N, L, E, P, W, U, B, G]
- niter = 10
- size = 5
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
- for var in variables:
- for _ in range(niter):
- assert next(sample(var)) in var.pspace.domain.set
- samps = next(sample(var, size=size))
- for samp in samps:
- assert samp in var.pspace.domain.set
-
def test_input_value_assertions():
a, b = symbols('a b')
p, q = symbols('p q', positive=True)
diff --git a/sympy/stats/tests/test_discrete_rv.py b/sympy/stats/tests/test_discrete_rv.py
index 56304da869..e4a693ec20 100644
--- a/sympy/stats/tests/test_discrete_rv.py
+++ b/sympy/stats/tests/test_discrete_rv.py
@@ -12,7 +12,6 @@
FlorySchulz, Poisson, Geometric, Hermite, Logarithmic,
NegativeBinomial, Skellam, YuleSimon, Zeta,
DiscreteRV)
-from sympy.stats.rv import sample
from sympy.testing.pytest import slow, nocache_fail, raises, skip, ignore_warnings
from sympy.external import import_module
from sympy.stats.symbolic_probability import Expectation
@@ -152,32 +151,6 @@ def test_zeta():
zeta(s) * zeta(s-2) - zeta(s-1)**2) / zeta(s)**2
-@slow
-def test_sample_discrete():
- X = Geometric('X', S.Half)
- scipy = import_module('scipy')
- if not scipy:
- skip('Scipy not installed. Abort tests')
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
- assert next(sample(X)) in X.pspace.domain.set
- samps = next(sample(X, size=2)) # This takes long time if ran without scipy
- for samp in samps:
- assert samp in X.pspace.domain.set
-
- libraries = ['scipy', 'numpy', 'pymc3']
- for lib in libraries:
- try:
- imported_lib = import_module(lib)
- if imported_lib:
- s0, s1, s2 = [], [], []
- s0 = list(sample(X, numsamples=10, library=lib, seed=0))
- s1 = list(sample(X, numsamples=10, library=lib, seed=0))
- s2 = list(sample(X, numsamples=10, library=lib, seed=1))
- assert s0 == s1
- assert s1 != s2
- except NotImplementedError:
- continue
-
def test_discrete_probability():
X = Geometric('X', Rational(1, 5))
Y = Poisson('Y', 4)
diff --git a/sympy/stats/tests/test_finite_rv.py b/sympy/stats/tests/test_finite_rv.py
index 08cdc38845..644eedde14 100644
--- a/sympy/stats/tests/test_finite_rv.py
+++ b/sympy/stats/tests/test_finite_rv.py
@@ -5,7 +5,7 @@
from sympy.matrices import Matrix
from sympy.stats import (DiscreteUniform, Die, Bernoulli, Coin, Binomial, BetaBinomial,
Hypergeometric, Rademacher, IdealSoliton, RobustSoliton, P, E, variance,
- covariance, skewness, sample, density, where, FiniteRV, pspace, cdf,
+ covariance, skewness, density, where, FiniteRV, pspace, cdf,
correlation, moment, cmoment, smoment, characteristic_function,
moment_generating_function, quantile, kurtosis, median, coskewness)
from sympy.stats.frv_types import DieDistribution, BinomialDistribution, \
@@ -146,11 +146,6 @@ def test_given():
X = Die('X', 6)
assert density(X, X > 5) == {S(6): S.One}
assert where(X > 2, X > 5).as_boolean() == Eq(X.symbol, 6)
- scipy = import_module('scipy')
- if not scipy:
- skip('Scipy is not installed. Abort tests')
- with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
- assert next(sample(X, X > 5)) == 6
def test_domains():
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment