Skip to content

Commit 5370976

Browse files
committed
add warnings for sample
1 parent 14a5441 commit 5370976

File tree

5 files changed

+96
-89
lines changed

5 files changed

+96
-89
lines changed

sympy/stats/rv.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from sympy.sets.sets import FiniteSet, ProductSet, Intersection
2727
from sympy.solvers.solveset import solveset
2828
from sympy.external import import_module
29-
from sympy.utilities.exceptions import SymPyDeprecationWarning
29+
from sympy.utilities.misc import filldedent
30+
import warnings
3031

3132

3233
x = Symbol('x')
@@ -1071,6 +1072,10 @@ def sample(expr, condition=None, size=(1,), library='scipy', numsamples=1,
10711072
iterator object containing the sample/samples of given expr
10721073
10731074
"""
1075+
message = ("The return type of sample has been changed to return an "
1076+
"iterator object since version 1.7. For more information see "
1077+
"https://github.com/sympy/sympy/issues/19061")
1078+
warnings.warn(filldedent(message))
10741079
return sample_iter(expr, condition, size=size, library=library,
10751080
numsamples=numsamples)
10761081

@@ -1201,24 +1206,12 @@ def return_generator():
12011206

12021207
def sample_iter_lambdify(expr, condition=None, size=(1,), numsamples=S.Infinity,
12031208
**kwargs):
1204-
SymPyDeprecationWarning(
1205-
feature='sample_iter_lambdify',
1206-
useinstead='sample_iter',
1207-
issue=19061,
1208-
deprecated_since_version=1.6,
1209-
).warn()
12101209

12111210
return sample_iter(expr, condition=condition, size=size, numsamples=numsamples,
12121211
**kwargs)
12131212

12141213
def sample_iter_subs(expr, condition=None, size=(1,), numsamples=S.Infinity,
12151214
**kwargs):
1216-
SymPyDeprecationWarning(
1217-
feature='sample_iter_subs',
1218-
useinstead='sample_iter',
1219-
issue=19061,
1220-
deprecated_since_version=1.6,
1221-
).warn()
12221215

12231216
return sample_iter(expr, condition=condition, size=size, numsamples=numsamples,
12241217
**kwargs)

sympy/stats/tests/test_continuous_rv.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sympy.stats.joint_rv_types import MultivariateLaplaceDistribution, MultivariateNormalDistribution
2525
from sympy.stats.crv import SingleContinuousPSpace, SingleContinuousDomain
2626
from sympy.stats.joint_rv import JointPSpace
27-
from sympy.testing.pytest import raises, XFAIL, slow, skip
27+
from sympy.testing.pytest import raises, XFAIL, slow, skip, ignore_warnings
2828
from sympy.testing.randtest import verify_numerically as tn
2929

3030
oo = S.Infinity
@@ -326,7 +326,8 @@ def test_sample_continuous():
326326
scipy = import_module('scipy')
327327
if not scipy:
328328
skip('Scipy is not installed. Abort tests')
329-
assert next(sample(Z))[0] in Z.pspace.domain.set
329+
with ignore_warnings(UserWarning):
330+
assert next(sample(Z))[0] in Z.pspace.domain.set
330331
sym, val = list(Z.pspace.sample().items())[0]
331332
assert sym == Z and val[0] in Interval(0, oo)
332333

@@ -735,7 +736,8 @@ def test_sampling_gamma_inverse():
735736
if not scipy:
736737
skip('Scipy not installed. Abort tests for sampling of gamma inverse.')
737738
X = GammaInverse("x", 1, 1)
738-
assert next(sample(X))[0] in X.pspace.domain.set
739+
with ignore_warnings(UserWarning):
740+
assert next(sample(X))[0] in X.pspace.domain.set
739741

740742
def test_gompertz():
741743
b = Symbol("b", positive=True)
@@ -861,14 +863,16 @@ def test_lognormal():
861863
scipy = import_module('scipy')
862864
if not scipy:
863865
skip('Scipy is not installed. Abort tests')
864-
for i in range(3):
865-
X = LogNormal('x', i, 1)
866-
assert next(sample(X))[0] in X.pspace.domain.set
866+
with ignore_warnings(UserWarning):
867+
for i in range(3):
868+
X = LogNormal('x', i, 1)
869+
assert next(sample(X))[0] in X.pspace.domain.set
867870

868871
size = 5
869-
samps = next(sample(X, size=size))
870-
for samp in samps:
871-
assert samp in X.pspace.domain.set
872+
with ignore_warnings(UserWarning):
873+
samps = next(sample(X, size=size))
874+
for samp in samps:
875+
assert samp in X.pspace.domain.set
872876
# The sympy integrator can't do this too well
873877
#assert E(X) ==
874878
raises(NotImplementedError, lambda: moment_generating_function(X))
@@ -971,7 +975,8 @@ def test_sampling_gaussian_inverse():
971975
if not scipy:
972976
skip('Scipy not installed. Abort tests for sampling of Gaussian inverse.')
973977
X = GaussianInverse("x", 1, 1)
974-
assert next(sample(X, library='scipy'))[0] in X.pspace.domain.set
978+
with ignore_warnings(UserWarning):
979+
assert next(sample(X, library='scipy'))[0] in X.pspace.domain.set
975980

976981
def test_pareto():
977982
xm, beta = symbols('xm beta', positive=True)
@@ -1257,6 +1262,9 @@ def test_wignersemicircle():
12571262

12581263

12591264
def test_prefab_sampling():
1265+
scipy = import_module('scipy')
1266+
if not scipy:
1267+
skip('Scipy is not installed. Abort tests')
12601268
N = Normal('X', 0, 1)
12611269
L = LogNormal('L', 0, 1)
12621270
E = Exponential('Ex', 1)
@@ -1269,15 +1277,13 @@ def test_prefab_sampling():
12691277
variables = [N, L, E, P, W, U, B, G]
12701278
niter = 10
12711279
size = 5
1272-
scipy = import_module('scipy')
1273-
if not scipy:
1274-
skip('Scipy is not installed. Abort tests')
1275-
for var in variables:
1276-
for i in range(niter):
1277-
assert next(sample(var))[0] in var.pspace.domain.set
1278-
samps = next(sample(var, size=size))
1279-
for samp in samps:
1280-
assert samp in var.pspace.domain.set
1280+
with ignore_warnings(UserWarning):
1281+
for var in variables:
1282+
for i in range(niter):
1283+
assert next(sample(var))[0] in var.pspace.domain.set
1284+
samps = next(sample(var, size=size))
1285+
for samp in samps:
1286+
assert samp in var.pspace.domain.set
12811287

12821288
def test_input_value_assertions():
12831289
a, b = symbols('a b')
@@ -1539,10 +1545,11 @@ def test_sample_numpy():
15391545
if not numpy:
15401546
skip('Numpy is not installed. Abort tests for _sample_numpy.')
15411547
else:
1542-
for X in distribs_numpy:
1543-
samps = next(sample(X, size=size, library='numpy'))
1544-
for sam in samps:
1545-
assert sam in X.pspace.domain.set
1548+
with ignore_warnings(UserWarning):
1549+
for X in distribs_numpy:
1550+
samps = next(sample(X, size=size, library='numpy'))
1551+
for sam in samps:
1552+
assert sam in X.pspace.domain.set
15461553

15471554

15481555
def test_sample_scipy():
@@ -1568,16 +1575,17 @@ def test_sample_scipy():
15681575
if not scipy:
15691576
skip('Scipy is not installed. Abort tests for _sample_scipy.')
15701577
else:
1571-
g_sample = list(sample(Gamma("G", 2, 7), size=size, numsamples=numsamples))
1572-
assert len(g_sample) == numsamples
1573-
for X in distribs_scipy:
1574-
samps = next(sample(X, size=size, library='scipy'))
1575-
samps2 = next(sample(X, size=(2, 2), library='scipy'))
1576-
for sam in samps:
1577-
assert sam in X.pspace.domain.set
1578-
for i in range(2):
1579-
for j in range(2):
1580-
assert samps2[i][j] in X.pspace.domain.set
1578+
with ignore_warnings(UserWarning):
1579+
g_sample = list(sample(Gamma("G", 2, 7), size=size, numsamples=numsamples))
1580+
assert len(g_sample) == numsamples
1581+
for X in distribs_scipy:
1582+
samps = next(sample(X, size=size, library='scipy'))
1583+
samps2 = next(sample(X, size=(2, 2), library='scipy'))
1584+
for sam in samps:
1585+
assert sam in X.pspace.domain.set
1586+
for i in range(2):
1587+
for j in range(2):
1588+
assert samps2[i][j] in X.pspace.domain.set
15811589

15821590
def test_sample_pymc3():
15831591
distribs_pymc3 = [
@@ -1597,10 +1605,11 @@ def test_sample_pymc3():
15971605
if not pymc3:
15981606
skip('PyMC3 is not installed. Abort tests for _sample_pymc3.')
15991607
else:
1600-
for X in distribs_pymc3:
1601-
samps = next(sample(X, size=size, library='pymc3'))
1602-
for sam in samps:
1603-
assert sam in X.pspace.domain.set
1608+
with ignore_warnings(UserWarning):
1609+
for X in distribs_pymc3:
1610+
samps = next(sample(X, size=size, library='pymc3'))
1611+
for sam in samps:
1612+
assert sam in X.pspace.domain.set
16041613

16051614

16061615
def test_issue_16318():

sympy/stats/tests/test_discrete_rv.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
NegativeBinomial, Skellam, YuleSimon, Zeta,
1414
DiscreteRV)
1515
from sympy.stats.rv import sample
16-
from sympy.testing.pytest import slow, nocache_fail, raises, skip
16+
from sympy.testing.pytest import slow, nocache_fail, raises, skip, ignore_warnings
1717
from sympy.external import import_module
1818

1919
x = Symbol('x')
@@ -140,10 +140,11 @@ def test_sample_discrete():
140140
scipy = import_module('scipy')
141141
if not scipy:
142142
skip('Scipy not installed. Abort tests')
143-
assert next(sample(X))[0] in X.pspace.domain.set
144-
samps = next(sample(X, size=2)) # This takes long time if ran without scipy
145-
for samp in samps:
146-
assert samp in X.pspace.domain.set
143+
with ignore_warnings(UserWarning):
144+
assert next(sample(X))[0] in X.pspace.domain.set
145+
samps = next(sample(X, size=2)) # This takes long time if ran without scipy
146+
for samp in samps:
147+
assert samp in X.pspace.domain.set
147148

148149
def test_discrete_probability():
149150
X = Geometric('X', Rational(1, 5))
@@ -306,23 +307,26 @@ def test_sampling_methods():
306307
if not numpy:
307308
skip('Numpy is not installed. Abort tests for _sample_numpy.')
308309
else:
309-
for X in distribs_numpy:
310-
samps = X.pspace.distribution._sample_numpy(size)
311-
for samp in samps:
312-
assert samp in X.pspace.domain.set
310+
with ignore_warnings(UserWarning):
311+
for X in distribs_numpy:
312+
samps = X.pspace.distribution._sample_numpy(size)
313+
for samp in samps:
314+
assert samp in X.pspace.domain.set
313315
scipy = import_module('scipy')
314316
if not scipy:
315317
skip('Scipy is not installed. Abort tests for _sample_scipy.')
316318
else:
317-
for X in distribs_scipy:
318-
samps = next(sample(X, size=size))
319-
for samp in samps:
320-
assert samp in X.pspace.domain.set
319+
with ignore_warnings(UserWarning):
320+
for X in distribs_scipy:
321+
samps = next(sample(X, size=size))
322+
for samp in samps:
323+
assert samp in X.pspace.domain.set
321324
pymc3 = import_module('pymc3')
322325
if not pymc3:
323326
skip('PyMC3 is not installed. Abort tests for _sample_pymc3.')
324327
else:
325-
for X in distribs_pymc3:
326-
samps = X.pspace.distribution._sample_pymc3(size)
327-
for samp in samps:
328-
assert samp in X.pspace.domain.set
328+
with ignore_warnings(UserWarning):
329+
for X in distribs_pymc3:
330+
samps = X.pspace.distribution._sample_pymc3(size)
331+
for samp in samps:
332+
assert samp in X.pspace.domain.set

sympy/stats/tests/test_finite_rv.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sympy.stats.frv_types import DieDistribution, BinomialDistribution, \
1212
HypergeometricDistribution
1313
from sympy.stats.rv import Density
14-
from sympy.testing.pytest import raises, skip
14+
from sympy.testing.pytest import raises, skip, ignore_warnings
1515

1616

1717
def BayesTest(A, B):
@@ -149,7 +149,8 @@ def test_given():
149149
scipy = import_module('scipy')
150150
if not scipy:
151151
skip('Scipy is not installed. Abort tests')
152-
assert next(sample(X, X > 5)) == 6
152+
with ignore_warnings(UserWarning):
153+
assert next(sample(X, X > 5)) == 6
153154

154155

155156
def test_domains():
@@ -448,25 +449,27 @@ def test_sampling_methods():
448449
distribs_pymc3 = [BetaBinomial("B", 1, 1, 1)]
449450

450451
size = 5
451-
452-
for X in distribs_random:
453-
sam = X.pspace.distribution._sample_random(size)
454-
for i in range(size):
455-
assert sam[i] in X.pspace.domain.set
452+
with ignore_warnings(UserWarning):
453+
for X in distribs_random:
454+
sam = X.pspace.distribution._sample_random(size)
455+
for i in range(size):
456+
assert sam[i] in X.pspace.domain.set
456457

457458
scipy = import_module('scipy')
458459
if not scipy:
459460
skip('Scipy not installed. Abort tests for _sample_scipy.')
460461
else:
461-
for X in distribs_scipy:
462-
sam = X.pspace.distribution._sample_scipy(size)
463-
for i in range(size):
464-
assert sam[i] in X.pspace.domain.set
462+
with ignore_warnings(UserWarning):
463+
for X in distribs_scipy:
464+
sam = X.pspace.distribution._sample_scipy(size)
465+
for i in range(size):
466+
assert sam[i] in X.pspace.domain.set
465467
pymc3 = import_module('pymc3')
466468
if not pymc3:
467469
skip('PyMC3 not installed. Abort tests for _sample_pymc3.')
468470
else:
469-
for X in distribs_pymc3:
470-
sam = X.pspace.distribution._sample_pymc3(size)
471-
for i in range(size):
472-
assert sam[i] in X.pspace.domain.set
471+
with ignore_warnings(UserWarning):
472+
for X in distribs_pymc3:
473+
sam = X.pspace.distribution._sample_pymc3(size)
474+
for i in range(size):
475+
assert sam[i] in X.pspace.domain.set

sympy/stats/tests/test_rv.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
random_symbols, sample, Geometric, factorial_moment, Binomial, Hypergeometric,
88
DiscreteUniform, Poisson, characteristic_function, moment_generating_function)
99
from sympy.stats.rv import (IndependentProductPSpace, rs_swap, Density, NamedArgsMixin,
10-
RandomSymbol, sample_iter, PSpace, sample_iter_subs, sample_iter_lambdify)
11-
from sympy.testing.pytest import raises, skip, XFAIL, warns_deprecated_sympy
10+
RandomSymbol, sample_iter, PSpace)
11+
from sympy.testing.pytest import raises, skip, XFAIL, ignore_warnings
1212
from sympy.external import import_module
1313
from sympy.core.numbers import comp
1414
from sympy.stats.frv_types import BernoulliDistribution
@@ -104,9 +104,6 @@ def is_iterator(obj):
104104
return True
105105
else:
106106
return False
107-
with warns_deprecated_sympy():
108-
sample_iter_subs(expr)
109-
sample_iter_lambdify(expr)
110107
assert is_iterator(iterator)
111108
assert is_iterator(iterator2)
112109
assert is_iterator(iterator3)
@@ -191,8 +188,9 @@ def test_Sample():
191188
scipy = import_module('scipy')
192189
if not scipy:
193190
skip('Scipy is not installed. Abort tests')
194-
assert next(sample(X)) in [1, 2, 3, 4, 5, 6]
195-
assert next(sample(X + Y))[0].is_Float
191+
with ignore_warnings(UserWarning):
192+
assert next(sample(X)) in [1, 2, 3, 4, 5, 6]
193+
assert next(sample(X + Y))[0].is_Float
196194

197195
assert P(X + Y > 0, Y < 0, numsamples=10).is_number
198196
assert E(X + Y, numsamples=10).is_number

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy