rv.py 53 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798
  1. """
  2. Main Random Variables Module
  3. Defines abstract random variable type.
  4. Contains interfaces for probability space object (PSpace) as well as standard
  5. operators, P, E, sample, density, where, quantile
  6. See Also
  7. ========
  8. sympy.stats.crv
  9. sympy.stats.frv
  10. sympy.stats.rv_interface
  11. """
  12. from __future__ import annotations
  13. from functools import singledispatch
  14. from math import prod
  15. from sympy.core.add import Add
  16. from sympy.core.basic import Basic
  17. from sympy.core.containers import Tuple
  18. from sympy.core.expr import Expr
  19. from sympy.core.function import (Function, Lambda)
  20. from sympy.core.logic import fuzzy_and
  21. from sympy.core.mul import Mul
  22. from sympy.core.relational import (Eq, Ne)
  23. from sympy.core.singleton import S
  24. from sympy.core.symbol import (Dummy, Symbol)
  25. from sympy.core.sympify import sympify
  26. from sympy.functions.special.delta_functions import DiracDelta
  27. from sympy.functions.special.tensor_functions import KroneckerDelta
  28. from sympy.logic.boolalg import (And, Or)
  29. from sympy.matrices.expressions.matexpr import MatrixSymbol
  30. from sympy.tensor.indexed import Indexed
  31. from sympy.utilities.lambdify import lambdify
  32. from sympy.core.relational import Relational
  33. from sympy.core.sympify import _sympify
  34. from sympy.sets.sets import FiniteSet, ProductSet, Intersection
  35. from sympy.solvers.solveset import solveset
  36. from sympy.external import import_module
  37. from sympy.utilities.decorator import doctest_depends_on
  38. from sympy.utilities.exceptions import sympy_deprecation_warning
  39. from sympy.utilities.iterables import iterable
  40. __doctest_requires__ = {('sample',): ['scipy']}
  41. x = Symbol('x')
  42. @singledispatch
  43. def is_random(x):
  44. return False
  45. @is_random.register(Basic)
  46. def _(x):
  47. atoms = x.free_symbols
  48. return any(is_random(i) for i in atoms)
  49. class RandomDomain(Basic):
  50. """
  51. Represents a set of variables and the values which they can take.
  52. See Also
  53. ========
  54. sympy.stats.crv.ContinuousDomain
  55. sympy.stats.frv.FiniteDomain
  56. """
  57. is_ProductDomain = False
  58. is_Finite = False
  59. is_Continuous = False
  60. is_Discrete = False
  61. def __new__(cls, symbols, *args):
  62. symbols = FiniteSet(*symbols)
  63. return Basic.__new__(cls, symbols, *args)
  64. @property
  65. def symbols(self):
  66. return self.args[0]
  67. @property
  68. def set(self):
  69. return self.args[1]
  70. def __contains__(self, other):
  71. raise NotImplementedError()
  72. def compute_expectation(self, expr):
  73. raise NotImplementedError()
  74. class SingleDomain(RandomDomain):
  75. """
  76. A single variable and its domain.
  77. See Also
  78. ========
  79. sympy.stats.crv.SingleContinuousDomain
  80. sympy.stats.frv.SingleFiniteDomain
  81. """
  82. def __new__(cls, symbol, set):
  83. assert symbol.is_Symbol
  84. return Basic.__new__(cls, symbol, set)
  85. @property
  86. def symbol(self):
  87. return self.args[0]
  88. @property
  89. def symbols(self):
  90. return FiniteSet(self.symbol)
  91. def __contains__(self, other):
  92. if len(other) != 1:
  93. return False
  94. sym, val = tuple(other)[0]
  95. return self.symbol == sym and val in self.set
  96. class MatrixDomain(RandomDomain):
  97. """
  98. A Random Matrix variable and its domain.
  99. """
  100. def __new__(cls, symbol, set):
  101. symbol, set = _symbol_converter(symbol), _sympify(set)
  102. return Basic.__new__(cls, symbol, set)
  103. @property
  104. def symbol(self):
  105. return self.args[0]
  106. @property
  107. def symbols(self):
  108. return FiniteSet(self.symbol)
  109. class ConditionalDomain(RandomDomain):
  110. """
  111. A RandomDomain with an attached condition.
  112. See Also
  113. ========
  114. sympy.stats.crv.ConditionalContinuousDomain
  115. sympy.stats.frv.ConditionalFiniteDomain
  116. """
  117. def __new__(cls, fulldomain, condition):
  118. condition = condition.xreplace({rs: rs.symbol
  119. for rs in random_symbols(condition)})
  120. return Basic.__new__(cls, fulldomain, condition)
  121. @property
  122. def symbols(self):
  123. return self.fulldomain.symbols
  124. @property
  125. def fulldomain(self):
  126. return self.args[0]
  127. @property
  128. def condition(self):
  129. return self.args[1]
  130. @property
  131. def set(self):
  132. raise NotImplementedError("Set of Conditional Domain not Implemented")
  133. def as_boolean(self):
  134. return And(self.fulldomain.as_boolean(), self.condition)
  135. class PSpace(Basic):
  136. """
  137. A Probability Space.
  138. Explanation
  139. ===========
  140. Probability Spaces encode processes that equal different values
  141. probabilistically. These underly Random Symbols which occur in SymPy
  142. expressions and contain the mechanics to evaluate statistical statements.
  143. See Also
  144. ========
  145. sympy.stats.crv.ContinuousPSpace
  146. sympy.stats.frv.FinitePSpace
  147. """
  148. is_Finite: bool | None = None # Fails test if not set to None
  149. is_Continuous: bool | None = None # Fails test if not set to None
  150. is_Discrete: bool | None = None # Fails test if not set to None
  151. is_real: bool | None
  152. @property
  153. def domain(self):
  154. return self.args[0]
  155. @property
  156. def density(self):
  157. return self.args[1]
  158. @property
  159. def values(self):
  160. return frozenset(RandomSymbol(sym, self) for sym in self.symbols)
  161. @property
  162. def symbols(self):
  163. return self.domain.symbols
  164. def where(self, condition):
  165. raise NotImplementedError()
  166. def compute_density(self, expr):
  167. raise NotImplementedError()
  168. def sample(self, size=(), library='scipy', seed=None):
  169. raise NotImplementedError()
  170. def probability(self, condition):
  171. raise NotImplementedError()
  172. def compute_expectation(self, expr):
  173. raise NotImplementedError()
  174. class SinglePSpace(PSpace):
  175. """
  176. Represents the probabilities of a set of random events that can be
  177. attributed to a single variable/symbol.
  178. """
  179. def __new__(cls, s, distribution):
  180. s = _symbol_converter(s)
  181. return Basic.__new__(cls, s, distribution)
  182. @property
  183. def value(self):
  184. return RandomSymbol(self.symbol, self)
  185. @property
  186. def symbol(self):
  187. return self.args[0]
  188. @property
  189. def distribution(self):
  190. return self.args[1]
  191. @property
  192. def pdf(self):
  193. return self.distribution.pdf(self.symbol)
  194. class RandomSymbol(Expr):
  195. """
  196. Random Symbols represent ProbabilitySpaces in SymPy Expressions.
  197. In principle they can take on any value that their symbol can take on
  198. within the associated PSpace with probability determined by the PSpace
  199. Density.
  200. Explanation
  201. ===========
  202. Random Symbols contain pspace and symbol properties.
  203. The pspace property points to the represented Probability Space
  204. The symbol is a standard SymPy Symbol that is used in that probability space
  205. for example in defining a density.
  206. You can form normal SymPy expressions using RandomSymbols and operate on
  207. those expressions with the Functions
  208. E - Expectation of a random expression
  209. P - Probability of a condition
  210. density - Probability Density of an expression
  211. given - A new random expression (with new random symbols) given a condition
  212. An object of the RandomSymbol type should almost never be created by the
  213. user. They tend to be created instead by the PSpace class's value method.
  214. Traditionally a user does not even do this but instead calls one of the
  215. convenience functions Normal, Exponential, Coin, Die, FiniteRV, etc....
  216. """
  217. def __new__(cls, symbol, pspace=None):
  218. from sympy.stats.joint_rv import JointRandomSymbol
  219. if pspace is None:
  220. # Allow single arg, representing pspace == PSpace()
  221. pspace = PSpace()
  222. symbol = _symbol_converter(symbol)
  223. if not isinstance(pspace, PSpace):
  224. raise TypeError("pspace variable should be of type PSpace")
  225. if cls == JointRandomSymbol and isinstance(pspace, SinglePSpace):
  226. cls = RandomSymbol
  227. return Basic.__new__(cls, symbol, pspace)
  228. is_finite = True
  229. is_symbol = True
  230. is_Atom = True
  231. _diff_wrt = True
  232. pspace = property(lambda self: self.args[1])
  233. symbol = property(lambda self: self.args[0])
  234. name = property(lambda self: self.symbol.name)
  235. def _eval_is_positive(self):
  236. return self.symbol.is_positive
  237. def _eval_is_integer(self):
  238. return self.symbol.is_integer
  239. def _eval_is_real(self):
  240. return self.symbol.is_real or self.pspace.is_real
  241. @property
  242. def is_commutative(self):
  243. return self.symbol.is_commutative
  244. @property
  245. def free_symbols(self):
  246. return {self}
  247. class RandomIndexedSymbol(RandomSymbol):
  248. def __new__(cls, idx_obj, pspace=None):
  249. if pspace is None:
  250. # Allow single arg, representing pspace == PSpace()
  251. pspace = PSpace()
  252. if not isinstance(idx_obj, (Indexed, Function)):
  253. raise TypeError("An Function or Indexed object is expected not %s"%(idx_obj))
  254. return Basic.__new__(cls, idx_obj, pspace)
  255. symbol = property(lambda self: self.args[0])
  256. name = property(lambda self: str(self.args[0]))
  257. @property
  258. def key(self):
  259. if isinstance(self.symbol, Indexed):
  260. return self.symbol.args[1]
  261. elif isinstance(self.symbol, Function):
  262. return self.symbol.args[0]
  263. @property
  264. def free_symbols(self):
  265. if self.key.free_symbols:
  266. free_syms = self.key.free_symbols
  267. free_syms.add(self)
  268. return free_syms
  269. return {self}
  270. @property
  271. def pspace(self):
  272. return self.args[1]
  273. class RandomMatrixSymbol(RandomSymbol, MatrixSymbol): # type: ignore
  274. def __new__(cls, symbol, n, m, pspace=None):
  275. n, m = _sympify(n), _sympify(m)
  276. symbol = _symbol_converter(symbol)
  277. if pspace is None:
  278. # Allow single arg, representing pspace == PSpace()
  279. pspace = PSpace()
  280. return Basic.__new__(cls, symbol, n, m, pspace)
  281. symbol = property(lambda self: self.args[0])
  282. pspace = property(lambda self: self.args[3])
  283. class ProductPSpace(PSpace):
  284. """
  285. Abstract class for representing probability spaces with multiple random
  286. variables.
  287. See Also
  288. ========
  289. sympy.stats.rv.IndependentProductPSpace
  290. sympy.stats.joint_rv.JointPSpace
  291. """
  292. pass
  293. class IndependentProductPSpace(ProductPSpace):
  294. """
  295. A probability space resulting from the merger of two independent probability
  296. spaces.
  297. Often created using the function, pspace.
  298. """
  299. def __new__(cls, *spaces):
  300. rs_space_dict = {}
  301. for space in spaces:
  302. for value in space.values:
  303. rs_space_dict[value] = space
  304. symbols = FiniteSet(*[val.symbol for val in rs_space_dict.keys()])
  305. # Overlapping symbols
  306. from sympy.stats.joint_rv import MarginalDistribution
  307. from sympy.stats.compound_rv import CompoundDistribution
  308. if len(symbols) < sum(len(space.symbols) for space in spaces if not
  309. isinstance(space.distribution, (
  310. CompoundDistribution, MarginalDistribution))):
  311. raise ValueError("Overlapping Random Variables")
  312. if all(space.is_Finite for space in spaces):
  313. from sympy.stats.frv import ProductFinitePSpace
  314. cls = ProductFinitePSpace
  315. obj = Basic.__new__(cls, *FiniteSet(*spaces))
  316. return obj
  317. @property
  318. def pdf(self):
  319. p = Mul(*[space.pdf for space in self.spaces])
  320. return p.subs({rv: rv.symbol for rv in self.values})
  321. @property
  322. def rs_space_dict(self):
  323. d = {}
  324. for space in self.spaces:
  325. for value in space.values:
  326. d[value] = space
  327. return d
  328. @property
  329. def symbols(self):
  330. return FiniteSet(*[val.symbol for val in self.rs_space_dict.keys()])
  331. @property
  332. def spaces(self):
  333. return FiniteSet(*self.args)
  334. @property
  335. def values(self):
  336. return sumsets(space.values for space in self.spaces)
  337. def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
  338. rvs = rvs or self.values
  339. rvs = frozenset(rvs)
  340. for space in self.spaces:
  341. expr = space.compute_expectation(expr, rvs & space.values, evaluate=False, **kwargs)
  342. if evaluate and hasattr(expr, 'doit'):
  343. return expr.doit(**kwargs)
  344. return expr
  345. @property
  346. def domain(self):
  347. return ProductDomain(*[space.domain for space in self.spaces])
  348. @property
  349. def density(self):
  350. raise NotImplementedError("Density not available for ProductSpaces")
  351. def sample(self, size=(), library='scipy', seed=None):
  352. return {k: v for space in self.spaces
  353. for k, v in space.sample(size=size, library=library, seed=seed).items()}
  354. def probability(self, condition, **kwargs):
  355. cond_inv = False
  356. if isinstance(condition, Ne):
  357. condition = Eq(condition.args[0], condition.args[1])
  358. cond_inv = True
  359. elif isinstance(condition, And): # they are independent
  360. return Mul(*[self.probability(arg) for arg in condition.args])
  361. elif isinstance(condition, Or): # they are independent
  362. return Add(*[self.probability(arg) for arg in condition.args])
  363. expr = condition.lhs - condition.rhs
  364. rvs = random_symbols(expr)
  365. dens = self.compute_density(expr)
  366. if any(pspace(rv).is_Continuous for rv in rvs):
  367. from sympy.stats.crv import SingleContinuousPSpace
  368. from sympy.stats.crv_types import ContinuousDistributionHandmade
  369. if expr in self.values:
  370. # Marginalize all other random symbols out of the density
  371. randomsymbols = tuple(set(self.values) - frozenset([expr]))
  372. symbols = tuple(rs.symbol for rs in randomsymbols)
  373. pdf = self.domain.integrate(self.pdf, symbols, **kwargs)
  374. return Lambda(expr.symbol, pdf)
  375. dens = ContinuousDistributionHandmade(dens)
  376. z = Dummy('z', real=True)
  377. space = SingleContinuousPSpace(z, dens)
  378. result = space.probability(condition.__class__(space.value, 0))
  379. else:
  380. from sympy.stats.drv import SingleDiscretePSpace
  381. from sympy.stats.drv_types import DiscreteDistributionHandmade
  382. dens = DiscreteDistributionHandmade(dens)
  383. z = Dummy('z', integer=True)
  384. space = SingleDiscretePSpace(z, dens)
  385. result = space.probability(condition.__class__(space.value, 0))
  386. return result if not cond_inv else S.One - result
  387. def compute_density(self, expr, **kwargs):
  388. rvs = random_symbols(expr)
  389. if any(pspace(rv).is_Continuous for rv in rvs):
  390. z = Dummy('z', real=True)
  391. expr = self.compute_expectation(DiracDelta(expr - z),
  392. **kwargs)
  393. else:
  394. z = Dummy('z', integer=True)
  395. expr = self.compute_expectation(KroneckerDelta(expr, z),
  396. **kwargs)
  397. return Lambda(z, expr)
  398. def compute_cdf(self, expr, **kwargs):
  399. raise ValueError("CDF not well defined on multivariate expressions")
  400. def conditional_space(self, condition, normalize=True, **kwargs):
  401. rvs = random_symbols(condition)
  402. condition = condition.xreplace({rv: rv.symbol for rv in self.values})
  403. pspaces = [pspace(rv) for rv in rvs]
  404. if any(ps.is_Continuous for ps in pspaces):
  405. from sympy.stats.crv import (ConditionalContinuousDomain,
  406. ContinuousPSpace)
  407. space = ContinuousPSpace
  408. domain = ConditionalContinuousDomain(self.domain, condition)
  409. elif any(ps.is_Discrete for ps in pspaces):
  410. from sympy.stats.drv import (ConditionalDiscreteDomain,
  411. DiscretePSpace)
  412. space = DiscretePSpace
  413. domain = ConditionalDiscreteDomain(self.domain, condition)
  414. elif all(ps.is_Finite for ps in pspaces):
  415. from sympy.stats.frv import FinitePSpace
  416. return FinitePSpace.conditional_space(self, condition)
  417. if normalize:
  418. replacement = {rv: Dummy(str(rv)) for rv in self.symbols}
  419. norm = domain.compute_expectation(self.pdf, **kwargs)
  420. pdf = self.pdf / norm.xreplace(replacement)
  421. # XXX: Converting symbols from set to tuple. The order matters to
  422. # Lambda though so we shouldn't be starting with a set here...
  423. density = Lambda(tuple(domain.symbols), pdf)
  424. return space(domain, density)
  425. class ProductDomain(RandomDomain):
  426. """
  427. A domain resulting from the merger of two independent domains.
  428. See Also
  429. ========
  430. sympy.stats.crv.ProductContinuousDomain
  431. sympy.stats.frv.ProductFiniteDomain
  432. """
  433. is_ProductDomain = True
  434. def __new__(cls, *domains):
  435. # Flatten any product of products
  436. domains2 = []
  437. for domain in domains:
  438. if not domain.is_ProductDomain:
  439. domains2.append(domain)
  440. else:
  441. domains2.extend(domain.domains)
  442. domains2 = FiniteSet(*domains2)
  443. if all(domain.is_Finite for domain in domains2):
  444. from sympy.stats.frv import ProductFiniteDomain
  445. cls = ProductFiniteDomain
  446. if all(domain.is_Continuous for domain in domains2):
  447. from sympy.stats.crv import ProductContinuousDomain
  448. cls = ProductContinuousDomain
  449. if all(domain.is_Discrete for domain in domains2):
  450. from sympy.stats.drv import ProductDiscreteDomain
  451. cls = ProductDiscreteDomain
  452. return Basic.__new__(cls, *domains2)
  453. @property
  454. def sym_domain_dict(self):
  455. return {symbol: domain for domain in self.domains
  456. for symbol in domain.symbols}
  457. @property
  458. def symbols(self):
  459. return FiniteSet(*[sym for domain in self.domains
  460. for sym in domain.symbols])
  461. @property
  462. def domains(self):
  463. return self.args
  464. @property
  465. def set(self):
  466. return ProductSet(*(domain.set for domain in self.domains))
  467. def __contains__(self, other):
  468. # Split event into each subdomain
  469. for domain in self.domains:
  470. # Collect the parts of this event which associate to this domain
  471. elem = frozenset([item for item in other
  472. if sympify(domain.symbols.contains(item[0]))
  473. is S.true])
  474. # Test this sub-event
  475. if elem not in domain:
  476. return False
  477. # All subevents passed
  478. return True
  479. def as_boolean(self):
  480. return And(*[domain.as_boolean() for domain in self.domains])
  481. def random_symbols(expr):
  482. """
  483. Returns all RandomSymbols within a SymPy Expression.
  484. """
  485. atoms = getattr(expr, 'atoms', None)
  486. if atoms is not None:
  487. comp = lambda rv: rv.symbol.name
  488. l = list(atoms(RandomSymbol))
  489. return sorted(l, key=comp)
  490. else:
  491. return []
  492. def pspace(expr):
  493. """
  494. Returns the underlying Probability Space of a random expression.
  495. For internal use.
  496. Examples
  497. ========
  498. >>> from sympy.stats import pspace, Normal
  499. >>> X = Normal('X', 0, 1)
  500. >>> pspace(2*X + 1) == X.pspace
  501. True
  502. """
  503. expr = sympify(expr)
  504. if isinstance(expr, RandomSymbol) and expr.pspace is not None:
  505. return expr.pspace
  506. if expr.has(RandomMatrixSymbol):
  507. rm = list(expr.atoms(RandomMatrixSymbol))[0]
  508. return rm.pspace
  509. rvs = random_symbols(expr)
  510. if not rvs:
  511. raise ValueError("Expression containing Random Variable expected, not %s" % (expr))
  512. # If only one space present
  513. if all(rv.pspace == rvs[0].pspace for rv in rvs):
  514. return rvs[0].pspace
  515. from sympy.stats.compound_rv import CompoundPSpace
  516. from sympy.stats.stochastic_process import StochasticPSpace
  517. for rv in rvs:
  518. if isinstance(rv.pspace, (CompoundPSpace, StochasticPSpace)):
  519. return rv.pspace
  520. # Otherwise make a product space
  521. return IndependentProductPSpace(*[rv.pspace for rv in rvs])
  522. def sumsets(sets):
  523. """
  524. Union of sets
  525. """
  526. return frozenset().union(*sets)
  527. def rs_swap(a, b):
  528. """
  529. Build a dictionary to swap RandomSymbols based on their underlying symbol.
  530. i.e.
  531. if ``X = ('x', pspace1)``
  532. and ``Y = ('x', pspace2)``
  533. then ``X`` and ``Y`` match and the key, value pair
  534. ``{X:Y}`` will appear in the result
  535. Inputs: collections a and b of random variables which share common symbols
  536. Output: dict mapping RVs in a to RVs in b
  537. """
  538. d = {}
  539. for rsa in a:
  540. d[rsa] = [rsb for rsb in b if rsa.symbol == rsb.symbol][0]
  541. return d
  542. def given(expr, condition=None, **kwargs):
  543. r""" Conditional Random Expression.
  544. Explanation
  545. ===========
  546. From a random expression and a condition on that expression creates a new
  547. probability space from the condition and returns the same expression on that
  548. conditional probability space.
  549. Examples
  550. ========
  551. >>> from sympy.stats import given, density, Die
  552. >>> X = Die('X', 6)
  553. >>> Y = given(X, X > 3)
  554. >>> density(Y).dict
  555. {4: 1/3, 5: 1/3, 6: 1/3}
  556. Following convention, if the condition is a random symbol then that symbol
  557. is considered fixed.
  558. >>> from sympy.stats import Normal
  559. >>> from sympy import pprint
  560. >>> from sympy.abc import z
  561. >>> X = Normal('X', 0, 1)
  562. >>> Y = Normal('Y', 0, 1)
  563. >>> pprint(density(X + Y, Y)(z), use_unicode=False)
  564. 2
  565. -(-Y + z)
  566. -----------
  567. ___ 2
  568. \/ 2 *e
  569. ------------------
  570. ____
  571. 2*\/ pi
  572. """
  573. if not is_random(condition) or pspace_independent(expr, condition):
  574. return expr
  575. if isinstance(condition, RandomSymbol):
  576. condition = Eq(condition, condition.symbol)
  577. condsymbols = random_symbols(condition)
  578. if (isinstance(condition, Eq) and len(condsymbols) == 1 and
  579. not isinstance(pspace(expr).domain, ConditionalDomain)):
  580. rv = tuple(condsymbols)[0]
  581. results = solveset(condition, rv)
  582. if isinstance(results, Intersection) and S.Reals in results.args:
  583. results = list(results.args[1])
  584. sums = 0
  585. for res in results:
  586. temp = expr.subs(rv, res)
  587. if temp == True:
  588. return True
  589. if temp != False:
  590. # XXX: This seems nonsensical but preserves existing behaviour
  591. # after the change that Relational is no longer a subclass of
  592. # Expr. Here expr is sometimes Relational and sometimes Expr
  593. # but we are trying to add them with +=. This needs to be
  594. # fixed somehow.
  595. if sums == 0 and isinstance(expr, Relational):
  596. sums = expr.subs(rv, res)
  597. else:
  598. sums += expr.subs(rv, res)
  599. if sums == 0:
  600. return False
  601. return sums
  602. # Get full probability space of both the expression and the condition
  603. fullspace = pspace(Tuple(expr, condition))
  604. # Build new space given the condition
  605. space = fullspace.conditional_space(condition, **kwargs)
  606. # Dictionary to swap out RandomSymbols in expr with new RandomSymbols
  607. # That point to the new conditional space
  608. swapdict = rs_swap(fullspace.values, space.values)
  609. # Swap random variables in the expression
  610. expr = expr.xreplace(swapdict)
  611. return expr
  612. def expectation(expr, condition=None, numsamples=None, evaluate=True, **kwargs):
  613. """
  614. Returns the expected value of a random expression.
  615. Parameters
  616. ==========
  617. expr : Expr containing RandomSymbols
  618. The expression of which you want to compute the expectation value
  619. given : Expr containing RandomSymbols
  620. A conditional expression. E(X, X>0) is expectation of X given X > 0
  621. numsamples : int
  622. Enables sampling and approximates the expectation with this many samples
  623. evalf : Bool (defaults to True)
  624. If sampling return a number rather than a complex expression
  625. evaluate : Bool (defaults to True)
  626. In case of continuous systems return unevaluated integral
  627. Examples
  628. ========
  629. >>> from sympy.stats import E, Die
  630. >>> X = Die('X', 6)
  631. >>> E(X)
  632. 7/2
  633. >>> E(2*X + 1)
  634. 8
  635. >>> E(X, X > 3) # Expectation of X given that it is above 3
  636. 5
  637. """
  638. if not is_random(expr): # expr isn't random?
  639. return expr
  640. kwargs['numsamples'] = numsamples
  641. from sympy.stats.symbolic_probability import Expectation
  642. if evaluate:
  643. return Expectation(expr, condition).doit(**kwargs)
  644. return Expectation(expr, condition)
  645. def probability(condition, given_condition=None, numsamples=None,
  646. evaluate=True, **kwargs):
  647. """
  648. Probability that a condition is true, optionally given a second condition.
  649. Parameters
  650. ==========
  651. condition : Combination of Relationals containing RandomSymbols
  652. The condition of which you want to compute the probability
  653. given_condition : Combination of Relationals containing RandomSymbols
  654. A conditional expression. P(X > 1, X > 0) is expectation of X > 1
  655. given X > 0
  656. numsamples : int
  657. Enables sampling and approximates the probability with this many samples
  658. evaluate : Bool (defaults to True)
  659. In case of continuous systems return unevaluated integral
  660. Examples
  661. ========
  662. >>> from sympy.stats import P, Die
  663. >>> from sympy import Eq
  664. >>> X, Y = Die('X', 6), Die('Y', 6)
  665. >>> P(X > 3)
  666. 1/2
  667. >>> P(Eq(X, 5), X > 2) # Probability that X == 5 given that X > 2
  668. 1/4
  669. >>> P(X > Y)
  670. 5/12
  671. """
  672. kwargs['numsamples'] = numsamples
  673. from sympy.stats.symbolic_probability import Probability
  674. if evaluate:
  675. return Probability(condition, given_condition).doit(**kwargs)
  676. return Probability(condition, given_condition)
  677. class Density(Basic):
  678. expr = property(lambda self: self.args[0])
  679. def __new__(cls, expr, condition = None):
  680. expr = _sympify(expr)
  681. if condition is None:
  682. obj = Basic.__new__(cls, expr)
  683. else:
  684. condition = _sympify(condition)
  685. obj = Basic.__new__(cls, expr, condition)
  686. return obj
  687. @property
  688. def condition(self):
  689. if len(self.args) > 1:
  690. return self.args[1]
  691. else:
  692. return None
  693. def doit(self, evaluate=True, **kwargs):
  694. from sympy.stats.random_matrix import RandomMatrixPSpace
  695. from sympy.stats.joint_rv import JointPSpace
  696. from sympy.stats.matrix_distributions import MatrixPSpace
  697. from sympy.stats.compound_rv import CompoundPSpace
  698. from sympy.stats.frv import SingleFiniteDistribution
  699. expr, condition = self.expr, self.condition
  700. if isinstance(expr, SingleFiniteDistribution):
  701. return expr.dict
  702. if condition is not None:
  703. # Recompute on new conditional expr
  704. expr = given(expr, condition, **kwargs)
  705. if not random_symbols(expr):
  706. return Lambda(x, DiracDelta(x - expr))
  707. if isinstance(expr, RandomSymbol):
  708. if isinstance(expr.pspace, (SinglePSpace, JointPSpace, MatrixPSpace)) and \
  709. hasattr(expr.pspace, 'distribution'):
  710. return expr.pspace.distribution
  711. elif isinstance(expr.pspace, RandomMatrixPSpace):
  712. return expr.pspace.model
  713. if isinstance(pspace(expr), CompoundPSpace):
  714. kwargs['compound_evaluate'] = evaluate
  715. result = pspace(expr).compute_density(expr, **kwargs)
  716. if evaluate and hasattr(result, 'doit'):
  717. return result.doit()
  718. else:
  719. return result
  720. def density(expr, condition=None, evaluate=True, numsamples=None, **kwargs):
  721. """
  722. Probability density of a random expression, optionally given a second
  723. condition.
  724. Explanation
  725. ===========
  726. This density will take on different forms for different types of
  727. probability spaces. Discrete variables produce Dicts. Continuous
  728. variables produce Lambdas.
  729. Parameters
  730. ==========
  731. expr : Expr containing RandomSymbols
  732. The expression of which you want to compute the density value
  733. condition : Relational containing RandomSymbols
  734. A conditional expression. density(X > 1, X > 0) is density of X > 1
  735. given X > 0
  736. numsamples : int
  737. Enables sampling and approximates the density with this many samples
  738. Examples
  739. ========
  740. >>> from sympy.stats import density, Die, Normal
  741. >>> from sympy import Symbol
  742. >>> x = Symbol('x')
  743. >>> D = Die('D', 6)
  744. >>> X = Normal(x, 0, 1)
  745. >>> density(D).dict
  746. {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6}
  747. >>> density(2*D).dict
  748. {2: 1/6, 4: 1/6, 6: 1/6, 8: 1/6, 10: 1/6, 12: 1/6}
  749. >>> density(X)(x)
  750. sqrt(2)*exp(-x**2/2)/(2*sqrt(pi))
  751. """
  752. if numsamples:
  753. return sampling_density(expr, condition, numsamples=numsamples,
  754. **kwargs)
  755. return Density(expr, condition).doit(evaluate=evaluate, **kwargs)
  756. def cdf(expr, condition=None, evaluate=True, **kwargs):
  757. """
  758. Cumulative Distribution Function of a random expression.
  759. optionally given a second condition.
  760. Explanation
  761. ===========
  762. This density will take on different forms for different types of
  763. probability spaces.
  764. Discrete variables produce Dicts.
  765. Continuous variables produce Lambdas.
  766. Examples
  767. ========
  768. >>> from sympy.stats import density, Die, Normal, cdf
  769. >>> D = Die('D', 6)
  770. >>> X = Normal('X', 0, 1)
  771. >>> density(D).dict
  772. {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6}
  773. >>> cdf(D)
  774. {1: 1/6, 2: 1/3, 3: 1/2, 4: 2/3, 5: 5/6, 6: 1}
  775. >>> cdf(3*D, D > 2)
  776. {9: 1/4, 12: 1/2, 15: 3/4, 18: 1}
  777. >>> cdf(X)
  778. Lambda(_z, erf(sqrt(2)*_z/2)/2 + 1/2)
  779. """
  780. if condition is not None: # If there is a condition
  781. # Recompute on new conditional expr
  782. return cdf(given(expr, condition, **kwargs), **kwargs)
  783. # Otherwise pass work off to the ProbabilitySpace
  784. result = pspace(expr).compute_cdf(expr, **kwargs)
  785. if evaluate and hasattr(result, 'doit'):
  786. return result.doit()
  787. else:
  788. return result
  789. def characteristic_function(expr, condition=None, evaluate=True, **kwargs):
  790. """
  791. Characteristic function of a random expression, optionally given a second condition.
  792. Returns a Lambda.
  793. Examples
  794. ========
  795. >>> from sympy.stats import Normal, DiscreteUniform, Poisson, characteristic_function
  796. >>> X = Normal('X', 0, 1)
  797. >>> characteristic_function(X)
  798. Lambda(_t, exp(-_t**2/2))
  799. >>> Y = DiscreteUniform('Y', [1, 2, 7])
  800. >>> characteristic_function(Y)
  801. Lambda(_t, exp(7*_t*I)/3 + exp(2*_t*I)/3 + exp(_t*I)/3)
  802. >>> Z = Poisson('Z', 2)
  803. >>> characteristic_function(Z)
  804. Lambda(_t, exp(2*exp(_t*I) - 2))
  805. """
  806. if condition is not None:
  807. return characteristic_function(given(expr, condition, **kwargs), **kwargs)
  808. result = pspace(expr).compute_characteristic_function(expr, **kwargs)
  809. if evaluate and hasattr(result, 'doit'):
  810. return result.doit()
  811. else:
  812. return result
  813. def moment_generating_function(expr, condition=None, evaluate=True, **kwargs):
  814. if condition is not None:
  815. return moment_generating_function(given(expr, condition, **kwargs), **kwargs)
  816. result = pspace(expr).compute_moment_generating_function(expr, **kwargs)
  817. if evaluate and hasattr(result, 'doit'):
  818. return result.doit()
  819. else:
  820. return result
  821. def where(condition, given_condition=None, **kwargs):
  822. """
  823. Returns the domain where a condition is True.
  824. Examples
  825. ========
  826. >>> from sympy.stats import where, Die, Normal
  827. >>> from sympy import And
  828. >>> D1, D2 = Die('a', 6), Die('b', 6)
  829. >>> a, b = D1.symbol, D2.symbol
  830. >>> X = Normal('x', 0, 1)
  831. >>> where(X**2<1)
  832. Domain: (-1 < x) & (x < 1)
  833. >>> where(X**2<1).set
  834. Interval.open(-1, 1)
  835. >>> where(And(D1<=D2, D2<3))
  836. Domain: (Eq(a, 1) & Eq(b, 1)) | (Eq(a, 1) & Eq(b, 2)) | (Eq(a, 2) & Eq(b, 2))
  837. """
  838. if given_condition is not None: # If there is a condition
  839. # Recompute on new conditional expr
  840. return where(given(condition, given_condition, **kwargs), **kwargs)
  841. # Otherwise pass work off to the ProbabilitySpace
  842. return pspace(condition).where(condition, **kwargs)
  843. @doctest_depends_on(modules=('scipy',))
  844. def sample(expr, condition=None, size=(), library='scipy',
  845. numsamples=1, seed=None, **kwargs):
  846. """
  847. A realization of the random expression.
  848. Parameters
  849. ==========
  850. expr : Expression of random variables
  851. Expression from which sample is extracted
  852. condition : Expr containing RandomSymbols
  853. A conditional expression
  854. size : int, tuple
  855. Represents size of each sample in numsamples
  856. library : str
  857. - 'scipy' : Sample using scipy
  858. - 'numpy' : Sample using numpy
  859. - 'pymc' : Sample using PyMC
  860. Choose any of the available options to sample from as string,
  861. by default is 'scipy'
  862. numsamples : int
  863. Number of samples, each with size as ``size``.
  864. .. deprecated:: 1.9
  865. The ``numsamples`` parameter is deprecated and is only provided for
  866. compatibility with v1.8. Use a list comprehension or an additional
  867. dimension in ``size`` instead. See
  868. :ref:`deprecated-sympy-stats-numsamples` for details.
  869. seed :
  870. An object to be used as seed by the given external library for sampling `expr`.
  871. Following is the list of possible types of object for the supported libraries,
  872. - 'scipy': int, numpy.random.RandomState, numpy.random.Generator
  873. - 'numpy': int, numpy.random.RandomState, numpy.random.Generator
  874. - 'pymc': int
  875. Optional, by default None, in which case seed settings
  876. related to the given library will be used.
  877. No modifications to environment's global seed settings
  878. are done by this argument.
  879. Returns
  880. =======
  881. sample: float/list/numpy.ndarray
  882. one sample or a collection of samples of the random expression.
  883. - sample(X) returns float/numpy.float64/numpy.int64 object.
  884. - sample(X, size=int/tuple) returns numpy.ndarray object.
  885. Examples
  886. ========
  887. >>> from sympy.stats import Die, sample, Normal, Geometric
  888. >>> X, Y, Z = Die('X', 6), Die('Y', 6), Die('Z', 6) # Finite Random Variable
  889. >>> die_roll = sample(X + Y + Z)
  890. >>> die_roll # doctest: +SKIP
  891. 3
  892. >>> N = Normal('N', 3, 4) # Continuous Random Variable
  893. >>> samp = sample(N)
  894. >>> samp in N.pspace.domain.set
  895. True
  896. >>> samp = sample(N, N>0)
  897. >>> samp > 0
  898. True
  899. >>> samp_list = sample(N, size=4)
  900. >>> [sam in N.pspace.domain.set for sam in samp_list]
  901. [True, True, True, True]
  902. >>> sample(N, size = (2,3)) # doctest: +SKIP
  903. array([[5.42519758, 6.40207856, 4.94991743],
  904. [1.85819627, 6.83403519, 1.9412172 ]])
  905. >>> G = Geometric('G', 0.5) # Discrete Random Variable
  906. >>> samp_list = sample(G, size=3)
  907. >>> samp_list # doctest: +SKIP
  908. [1, 3, 2]
  909. >>> [sam in G.pspace.domain.set for sam in samp_list]
  910. [True, True, True]
  911. >>> MN = Normal("MN", [3, 4], [[2, 1], [1, 2]]) # Joint Random Variable
  912. >>> samp_list = sample(MN, size=4)
  913. >>> samp_list # doctest: +SKIP
  914. [array([2.85768055, 3.38954165]),
  915. array([4.11163337, 4.3176591 ]),
  916. array([0.79115232, 1.63232916]),
  917. array([4.01747268, 3.96716083])]
  918. >>> [tuple(sam) in MN.pspace.domain.set for sam in samp_list]
  919. [True, True, True, True]
  920. .. versionchanged:: 1.7.0
  921. sample used to return an iterator containing the samples instead of value.
  922. .. versionchanged:: 1.9.0
  923. sample returns values or array of values instead of an iterator and numsamples is deprecated.
  924. """
  925. iterator = sample_iter(expr, condition, size=size, library=library,
  926. numsamples=numsamples, seed=seed)
  927. if numsamples != 1:
  928. sympy_deprecation_warning(
  929. f"""
  930. The numsamples parameter to sympy.stats.sample() is deprecated.
  931. Either use a list comprehension, like
  932. [sample(...) for i in range({numsamples})]
  933. or add a dimension to size, like
  934. sample(..., size={(numsamples,) + size})
  935. """,
  936. deprecated_since_version="1.9",
  937. active_deprecations_target="deprecated-sympy-stats-numsamples",
  938. )
  939. return [next(iterator) for i in range(numsamples)]
  940. return next(iterator)
  941. def quantile(expr, evaluate=True, **kwargs):
  942. r"""
  943. Return the :math:`p^{th}` order quantile of a probability distribution.
  944. Explanation
  945. ===========
  946. Quantile is defined as the value at which the probability of the random
  947. variable is less than or equal to the given probability.
  948. .. math::
  949. Q(p) = \inf\{x \in (-\infty, \infty) : p \le F(x)\}
  950. Examples
  951. ========
  952. >>> from sympy.stats import quantile, Die, Exponential
  953. >>> from sympy import Symbol, pprint
  954. >>> p = Symbol("p")
  955. >>> l = Symbol("lambda", positive=True)
  956. >>> X = Exponential("x", l)
  957. >>> quantile(X)(p)
  958. -log(1 - p)/lambda
  959. >>> D = Die("d", 6)
  960. >>> pprint(quantile(D)(p), use_unicode=False)
  961. /nan for Or(p > 1, p < 0)
  962. |
  963. | 1 for p <= 1/6
  964. |
  965. | 2 for p <= 1/3
  966. |
  967. < 3 for p <= 1/2
  968. |
  969. | 4 for p <= 2/3
  970. |
  971. | 5 for p <= 5/6
  972. |
  973. \ 6 for p <= 1
  974. """
  975. result = pspace(expr).compute_quantile(expr, **kwargs)
  976. if evaluate and hasattr(result, 'doit'):
  977. return result.doit()
  978. else:
  979. return result
  980. def sample_iter(expr, condition=None, size=(), library='scipy',
  981. numsamples=S.Infinity, seed=None, **kwargs):
  982. """
  983. Returns an iterator of realizations from the expression given a condition.
  984. Parameters
  985. ==========
  986. expr: Expr
  987. Random expression to be realized
  988. condition: Expr, optional
  989. A conditional expression
  990. size : int, tuple
  991. Represents size of each sample in numsamples
  992. numsamples: integer, optional
  993. Length of the iterator (defaults to infinity)
  994. seed :
  995. An object to be used as seed by the given external library for sampling `expr`.
  996. Following is the list of possible types of object for the supported libraries,
  997. - 'scipy': int, numpy.random.RandomState, numpy.random.Generator
  998. - 'numpy': int, numpy.random.RandomState, numpy.random.Generator
  999. - 'pymc': int
  1000. Optional, by default None, in which case seed settings
  1001. related to the given library will be used.
  1002. No modifications to environment's global seed settings
  1003. are done by this argument.
  1004. Examples
  1005. ========
  1006. >>> from sympy.stats import Normal, sample_iter
  1007. >>> X = Normal('X', 0, 1)
  1008. >>> expr = X*X + 3
  1009. >>> iterator = sample_iter(expr, numsamples=3) # doctest: +SKIP
  1010. >>> list(iterator) # doctest: +SKIP
  1011. [12, 4, 7]
  1012. Returns
  1013. =======
  1014. sample_iter: iterator object
  1015. iterator object containing the sample/samples of given expr
  1016. See Also
  1017. ========
  1018. sample
  1019. sampling_P
  1020. sampling_E
  1021. """
  1022. from sympy.stats.joint_rv import JointRandomSymbol
  1023. if not import_module(library):
  1024. raise ValueError("Failed to import %s" % library)
  1025. if condition is not None:
  1026. ps = pspace(Tuple(expr, condition))
  1027. else:
  1028. ps = pspace(expr)
  1029. rvs = list(ps.values)
  1030. if isinstance(expr, JointRandomSymbol):
  1031. expr = expr.subs({expr: RandomSymbol(expr.symbol, expr.pspace)})
  1032. else:
  1033. sub = {}
  1034. for arg in expr.args:
  1035. if isinstance(arg, JointRandomSymbol):
  1036. sub[arg] = RandomSymbol(arg.symbol, arg.pspace)
  1037. expr = expr.subs(sub)
  1038. def fn_subs(*args):
  1039. return expr.subs(dict(zip(rvs, args)))
  1040. def given_fn_subs(*args):
  1041. if condition is not None:
  1042. return condition.subs(dict(zip(rvs, args)))
  1043. return False
  1044. if library in ('pymc', 'pymc3'):
  1045. # Currently unable to lambdify in pymc
  1046. # TODO : Remove when lambdify accepts 'pymc' as module
  1047. fn = lambdify(rvs, expr, **kwargs)
  1048. else:
  1049. fn = lambdify(rvs, expr, modules=library, **kwargs)
  1050. if condition is not None:
  1051. given_fn = lambdify(rvs, condition, **kwargs)
  1052. def return_generator_infinite():
  1053. count = 0
  1054. _size = (1,)+((size,) if isinstance(size, int) else size)
  1055. while count < numsamples:
  1056. d = ps.sample(size=_size, library=library, seed=seed) # a dictionary that maps RVs to values
  1057. args = [d[rv][0] for rv in rvs]
  1058. if condition is not None: # Check that these values satisfy the condition
  1059. # TODO: Replace the try-except block with only given_fn(*args)
  1060. # once lambdify works with unevaluated SymPy objects.
  1061. try:
  1062. gd = given_fn(*args)
  1063. except (NameError, TypeError):
  1064. gd = given_fn_subs(*args)
  1065. if gd != True and gd != False:
  1066. raise ValueError(
  1067. "Conditions must not contain free symbols")
  1068. if not gd: # If the values don't satisfy then try again
  1069. continue
  1070. yield fn(*args)
  1071. count += 1
  1072. def return_generator_finite():
  1073. faulty = True
  1074. while faulty:
  1075. d = ps.sample(size=(numsamples,) + ((size,) if isinstance(size, int) else size),
  1076. library=library, seed=seed) # a dictionary that maps RVs to values
  1077. faulty = False
  1078. count = 0
  1079. while count < numsamples and not faulty:
  1080. args = [d[rv][count] for rv in rvs]
  1081. if condition is not None: # Check that these values satisfy the condition
  1082. # TODO: Replace the try-except block with only given_fn(*args)
  1083. # once lambdify works with unevaluated SymPy objects.
  1084. try:
  1085. gd = given_fn(*args)
  1086. except (NameError, TypeError):
  1087. gd = given_fn_subs(*args)
  1088. if gd != True and gd != False:
  1089. raise ValueError(
  1090. "Conditions must not contain free symbols")
  1091. if not gd: # If the values don't satisfy then try again
  1092. faulty = True
  1093. count += 1
  1094. count = 0
  1095. while count < numsamples:
  1096. args = [d[rv][count] for rv in rvs]
  1097. # TODO: Replace the try-except block with only fn(*args)
  1098. # once lambdify works with unevaluated SymPy objects.
  1099. try:
  1100. yield fn(*args)
  1101. except (NameError, TypeError):
  1102. yield fn_subs(*args)
  1103. count += 1
  1104. if numsamples is S.Infinity:
  1105. return return_generator_infinite()
  1106. return return_generator_finite()
  1107. def sample_iter_lambdify(expr, condition=None, size=(),
  1108. numsamples=S.Infinity, seed=None, **kwargs):
  1109. return sample_iter(expr, condition=condition, size=size,
  1110. numsamples=numsamples, seed=seed, **kwargs)
  1111. def sample_iter_subs(expr, condition=None, size=(),
  1112. numsamples=S.Infinity, seed=None, **kwargs):
  1113. return sample_iter(expr, condition=condition, size=size,
  1114. numsamples=numsamples, seed=seed, **kwargs)
  1115. def sampling_P(condition, given_condition=None, library='scipy', numsamples=1,
  1116. evalf=True, seed=None, **kwargs):
  1117. """
  1118. Sampling version of P.
  1119. See Also
  1120. ========
  1121. P
  1122. sampling_E
  1123. sampling_density
  1124. """
  1125. count_true = 0
  1126. count_false = 0
  1127. samples = sample_iter(condition, given_condition, library=library,
  1128. numsamples=numsamples, seed=seed, **kwargs)
  1129. for sample in samples:
  1130. if sample:
  1131. count_true += 1
  1132. else:
  1133. count_false += 1
  1134. result = S(count_true) / numsamples
  1135. if evalf:
  1136. return result.evalf()
  1137. else:
  1138. return result
  1139. def sampling_E(expr, given_condition=None, library='scipy', numsamples=1,
  1140. evalf=True, seed=None, **kwargs):
  1141. """
  1142. Sampling version of E.
  1143. See Also
  1144. ========
  1145. P
  1146. sampling_P
  1147. sampling_density
  1148. """
  1149. samples = list(sample_iter(expr, given_condition, library=library,
  1150. numsamples=numsamples, seed=seed, **kwargs))
  1151. result = Add(*samples) / numsamples
  1152. if evalf:
  1153. return result.evalf()
  1154. else:
  1155. return result
  1156. def sampling_density(expr, given_condition=None, library='scipy',
  1157. numsamples=1, seed=None, **kwargs):
  1158. """
  1159. Sampling version of density.
  1160. See Also
  1161. ========
  1162. density
  1163. sampling_P
  1164. sampling_E
  1165. """
  1166. results = {}
  1167. for result in sample_iter(expr, given_condition, library=library,
  1168. numsamples=numsamples, seed=seed, **kwargs):
  1169. results[result] = results.get(result, 0) + 1
  1170. return results
  1171. def dependent(a, b):
  1172. """
  1173. Dependence of two random expressions.
  1174. Two expressions are independent if knowledge of one does not change
  1175. computations on the other.
  1176. Examples
  1177. ========
  1178. >>> from sympy.stats import Normal, dependent, given
  1179. >>> from sympy import Tuple, Eq
  1180. >>> X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
  1181. >>> dependent(X, Y)
  1182. False
  1183. >>> dependent(2*X + Y, -Y)
  1184. True
  1185. >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3))
  1186. >>> dependent(X, Y)
  1187. True
  1188. See Also
  1189. ========
  1190. independent
  1191. """
  1192. if pspace_independent(a, b):
  1193. return False
  1194. z = Symbol('z', real=True)
  1195. # Dependent if density is unchanged when one is given information about
  1196. # the other
  1197. return (density(a, Eq(b, z)) != density(a) or
  1198. density(b, Eq(a, z)) != density(b))
  1199. def independent(a, b):
  1200. """
  1201. Independence of two random expressions.
  1202. Two expressions are independent if knowledge of one does not change
  1203. computations on the other.
  1204. Examples
  1205. ========
  1206. >>> from sympy.stats import Normal, independent, given
  1207. >>> from sympy import Tuple, Eq
  1208. >>> X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
  1209. >>> independent(X, Y)
  1210. True
  1211. >>> independent(2*X + Y, -Y)
  1212. False
  1213. >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3))
  1214. >>> independent(X, Y)
  1215. False
  1216. See Also
  1217. ========
  1218. dependent
  1219. """
  1220. return not dependent(a, b)
  1221. def pspace_independent(a, b):
  1222. """
  1223. Tests for independence between a and b by checking if their PSpaces have
  1224. overlapping symbols. This is a sufficient but not necessary condition for
  1225. independence and is intended to be used internally.
  1226. Notes
  1227. =====
  1228. pspace_independent(a, b) implies independent(a, b)
  1229. independent(a, b) does not imply pspace_independent(a, b)
  1230. """
  1231. a_symbols = set(pspace(b).symbols)
  1232. b_symbols = set(pspace(a).symbols)
  1233. if len(set(random_symbols(a)).intersection(random_symbols(b))) != 0:
  1234. return False
  1235. if len(a_symbols.intersection(b_symbols)) == 0:
  1236. return True
  1237. return None
  1238. def rv_subs(expr, symbols=None):
  1239. """
  1240. Given a random expression replace all random variables with their symbols.
  1241. If symbols keyword is given restrict the swap to only the symbols listed.
  1242. """
  1243. if symbols is None:
  1244. symbols = random_symbols(expr)
  1245. if not symbols:
  1246. return expr
  1247. swapdict = {rv: rv.symbol for rv in symbols}
  1248. return expr.subs(swapdict)
  1249. class NamedArgsMixin:
  1250. _argnames: tuple[str, ...] = ()
  1251. def __getattr__(self, attr):
  1252. try:
  1253. return self.args[self._argnames.index(attr)]
  1254. except ValueError:
  1255. raise AttributeError("'%s' object has no attribute '%s'" % (
  1256. type(self).__name__, attr))
  1257. class Distribution(Basic):
  1258. def sample(self, size=(), library='scipy', seed=None):
  1259. """ A random realization from the distribution """
  1260. module = import_module(library)
  1261. if library in {'scipy', 'numpy', 'pymc3', 'pymc'} and module is None:
  1262. raise ValueError("Failed to import %s" % library)
  1263. if library == 'scipy':
  1264. # scipy does not require map as it can handle using custom distributions.
  1265. # However, we will still use a map where we can.
  1266. # TODO: do this for drv.py and frv.py if necessary.
  1267. # TODO: add more distributions here if there are more
  1268. # See links below referring to sections beginning with "A common parametrization..."
  1269. # I will remove all these comments if everything is ok.
  1270. from sympy.stats.sampling.sample_scipy import do_sample_scipy
  1271. import numpy
  1272. if seed is None or isinstance(seed, int):
  1273. rand_state = numpy.random.default_rng(seed=seed)
  1274. else:
  1275. rand_state = seed
  1276. samps = do_sample_scipy(self, size, rand_state)
  1277. elif library == 'numpy':
  1278. from sympy.stats.sampling.sample_numpy import do_sample_numpy
  1279. import numpy
  1280. if seed is None or isinstance(seed, int):
  1281. rand_state = numpy.random.default_rng(seed=seed)
  1282. else:
  1283. rand_state = seed
  1284. _size = None if size == () else size
  1285. samps = do_sample_numpy(self, _size, rand_state)
  1286. elif library in ('pymc', 'pymc3'):
  1287. from sympy.stats.sampling.sample_pymc import do_sample_pymc
  1288. import logging
  1289. logging.getLogger("pymc").setLevel(logging.ERROR)
  1290. try:
  1291. import pymc
  1292. except ImportError:
  1293. import pymc3 as pymc
  1294. with pymc.Model():
  1295. if do_sample_pymc(self) is not None:
  1296. samps = pymc.sample(draws=prod(size), chains=1, compute_convergence_checks=False,
  1297. progressbar=False, random_seed=seed, return_inferencedata=False)[:]['X']
  1298. samps = samps.reshape(size)
  1299. else:
  1300. samps = None
  1301. else:
  1302. raise NotImplementedError("Sampling from %s is not supported yet."
  1303. % str(library))
  1304. if samps is not None:
  1305. return samps
  1306. raise NotImplementedError(
  1307. "Sampling for %s is not currently implemented from %s"
  1308. % (self, library))
  1309. def _value_check(condition, message):
  1310. """
  1311. Raise a ValueError with message if condition is False, else
  1312. return True if all conditions were True, else False.
  1313. Examples
  1314. ========
  1315. >>> from sympy.stats.rv import _value_check
  1316. >>> from sympy.abc import a, b, c
  1317. >>> from sympy import And, Dummy
  1318. >>> _value_check(2 < 3, '')
  1319. True
  1320. Here, the condition is not False, but it does not evaluate to True
  1321. so False is returned (but no error is raised). So checking if the
  1322. return value is True or False will tell you if all conditions were
  1323. evaluated.
  1324. >>> _value_check(a < b, '')
  1325. False
  1326. In this case the condition is False so an error is raised:
  1327. >>> r = Dummy(real=True)
  1328. >>> _value_check(r < r - 1, 'condition is not true')
  1329. Traceback (most recent call last):
  1330. ...
  1331. ValueError: condition is not true
  1332. If no condition of many conditions must be False, they can be
  1333. checked by passing them as an iterable:
  1334. >>> _value_check((a < 0, b < 0, c < 0), '')
  1335. False
  1336. The iterable can be a generator, too:
  1337. >>> _value_check((i < 0 for i in (a, b, c)), '')
  1338. False
  1339. The following are equivalent to the above but do not pass
  1340. an iterable:
  1341. >>> all(_value_check(i < 0, '') for i in (a, b, c))
  1342. False
  1343. >>> _value_check(And(a < 0, b < 0, c < 0), '')
  1344. False
  1345. """
  1346. if not iterable(condition):
  1347. condition = [condition]
  1348. truth = fuzzy_and(condition)
  1349. if truth == False:
  1350. raise ValueError(message)
  1351. return truth == True
  1352. def _symbol_converter(sym):
  1353. """
  1354. Casts the parameter to Symbol if it is 'str'
  1355. otherwise no operation is performed on it.
  1356. Parameters
  1357. ==========
  1358. sym
  1359. The parameter to be converted.
  1360. Returns
  1361. =======
  1362. Symbol
  1363. the parameter converted to Symbol.
  1364. Raises
  1365. ======
  1366. TypeError
  1367. If the parameter is not an instance of both str and
  1368. Symbol.
  1369. Examples
  1370. ========
  1371. >>> from sympy import Symbol
  1372. >>> from sympy.stats.rv import _symbol_converter
  1373. >>> s = _symbol_converter('s')
  1374. >>> isinstance(s, Symbol)
  1375. True
  1376. >>> _symbol_converter(1)
  1377. Traceback (most recent call last):
  1378. ...
  1379. TypeError: 1 is neither a Symbol nor a string
  1380. >>> r = Symbol('r')
  1381. >>> isinstance(r, Symbol)
  1382. True
  1383. """
  1384. if isinstance(sym, str):
  1385. sym = Symbol(sym)
  1386. if not isinstance(sym, Symbol):
  1387. raise TypeError("%s is neither a Symbol nor a string"%(sym))
  1388. return sym
  1389. def sample_stochastic_process(process):
  1390. """
  1391. This function is used to sample from stochastic process.
  1392. Parameters
  1393. ==========
  1394. process: StochasticProcess
  1395. Process used to extract the samples. It must be an instance of
  1396. StochasticProcess
  1397. Examples
  1398. ========
  1399. >>> from sympy.stats import sample_stochastic_process, DiscreteMarkovChain
  1400. >>> from sympy import Matrix
  1401. >>> T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]])
  1402. >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T)
  1403. >>> next(sample_stochastic_process(Y)) in Y.state_space
  1404. True
  1405. >>> next(sample_stochastic_process(Y)) # doctest: +SKIP
  1406. 0
  1407. >>> next(sample_stochastic_process(Y)) # doctest: +SKIP
  1408. 2
  1409. Returns
  1410. =======
  1411. sample: iterator object
  1412. iterator object containing the sample of given process
  1413. """
  1414. from sympy.stats.stochastic_process_types import StochasticProcess
  1415. if not isinstance(process, StochasticProcess):
  1416. raise ValueError("Process must be an instance of Stochastic Process")
  1417. return process.sample()