test_utils.py 69 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929
  1. import warnings
  2. import sys
  3. import os
  4. import itertools
  5. import pytest
  6. import weakref
  7. import re
  8. import numpy as np
  9. import numpy._core._multiarray_umath as ncu
  10. from numpy.testing import (
  11. assert_equal, assert_array_equal, assert_almost_equal,
  12. assert_array_almost_equal, assert_array_less, build_err_msg,
  13. assert_raises, assert_warns, assert_no_warnings, assert_allclose,
  14. assert_approx_equal, assert_array_almost_equal_nulp, assert_array_max_ulp,
  15. clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_,
  16. tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT
  17. )
  18. class _GenericTest:
  19. def _test_equal(self, a, b):
  20. self._assert_func(a, b)
  21. def _test_not_equal(self, a, b):
  22. with assert_raises(AssertionError):
  23. self._assert_func(a, b)
  24. def test_array_rank1_eq(self):
  25. """Test two equal array of rank 1 are found equal."""
  26. a = np.array([1, 2])
  27. b = np.array([1, 2])
  28. self._test_equal(a, b)
  29. def test_array_rank1_noteq(self):
  30. """Test two different array of rank 1 are found not equal."""
  31. a = np.array([1, 2])
  32. b = np.array([2, 2])
  33. self._test_not_equal(a, b)
  34. def test_array_rank2_eq(self):
  35. """Test two equal array of rank 2 are found equal."""
  36. a = np.array([[1, 2], [3, 4]])
  37. b = np.array([[1, 2], [3, 4]])
  38. self._test_equal(a, b)
  39. def test_array_diffshape(self):
  40. """Test two arrays with different shapes are found not equal."""
  41. a = np.array([1, 2])
  42. b = np.array([[1, 2], [1, 2]])
  43. self._test_not_equal(a, b)
  44. def test_objarray(self):
  45. """Test object arrays."""
  46. a = np.array([1, 1], dtype=object)
  47. self._test_equal(a, 1)
  48. def test_array_likes(self):
  49. self._test_equal([1, 2, 3], (1, 2, 3))
  50. class TestArrayEqual(_GenericTest):
  51. def setup_method(self):
  52. self._assert_func = assert_array_equal
  53. def test_generic_rank1(self):
  54. """Test rank 1 array for all dtypes."""
  55. def foo(t):
  56. a = np.empty(2, t)
  57. a.fill(1)
  58. b = a.copy()
  59. c = a.copy()
  60. c.fill(0)
  61. self._test_equal(a, b)
  62. self._test_not_equal(c, b)
  63. # Test numeric types and object
  64. for t in '?bhilqpBHILQPfdgFDG':
  65. foo(t)
  66. # Test strings
  67. for t in ['S1', 'U1']:
  68. foo(t)
  69. def test_0_ndim_array(self):
  70. x = np.array(473963742225900817127911193656584771)
  71. y = np.array(18535119325151578301457182298393896)
  72. with pytest.raises(AssertionError) as exc_info:
  73. self._assert_func(x, y)
  74. msg = str(exc_info.value)
  75. assert_('Mismatched elements: 1 / 1 (100%)\n'
  76. in msg)
  77. y = x
  78. self._assert_func(x, y)
  79. x = np.array(4395065348745.5643764887869876)
  80. y = np.array(0)
  81. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  82. 'Max absolute difference among violations: '
  83. '4.39506535e+12\n'
  84. 'Max relative difference among violations: inf\n')
  85. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  86. self._assert_func(x, y)
  87. x = y
  88. self._assert_func(x, y)
  89. def test_generic_rank3(self):
  90. """Test rank 3 array for all dtypes."""
  91. def foo(t):
  92. a = np.empty((4, 2, 3), t)
  93. a.fill(1)
  94. b = a.copy()
  95. c = a.copy()
  96. c.fill(0)
  97. self._test_equal(a, b)
  98. self._test_not_equal(c, b)
  99. # Test numeric types and object
  100. for t in '?bhilqpBHILQPfdgFDG':
  101. foo(t)
  102. # Test strings
  103. for t in ['S1', 'U1']:
  104. foo(t)
  105. def test_nan_array(self):
  106. """Test arrays with nan values in them."""
  107. a = np.array([1, 2, np.nan])
  108. b = np.array([1, 2, np.nan])
  109. self._test_equal(a, b)
  110. c = np.array([1, 2, 3])
  111. self._test_not_equal(c, b)
  112. def test_string_arrays(self):
  113. """Test two arrays with different shapes are found not equal."""
  114. a = np.array(['floupi', 'floupa'])
  115. b = np.array(['floupi', 'floupa'])
  116. self._test_equal(a, b)
  117. c = np.array(['floupipi', 'floupa'])
  118. self._test_not_equal(c, b)
  119. def test_recarrays(self):
  120. """Test record arrays."""
  121. a = np.empty(2, [('floupi', float), ('floupa', float)])
  122. a['floupi'] = [1, 2]
  123. a['floupa'] = [1, 2]
  124. b = a.copy()
  125. self._test_equal(a, b)
  126. c = np.empty(2, [('floupipi', float),
  127. ('floupi', float), ('floupa', float)])
  128. c['floupipi'] = a['floupi'].copy()
  129. c['floupa'] = a['floupa'].copy()
  130. with pytest.raises(TypeError):
  131. self._test_not_equal(c, b)
  132. def test_masked_nan_inf(self):
  133. # Regression test for gh-11121
  134. a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
  135. b = np.array([3., np.nan, 6.5])
  136. self._test_equal(a, b)
  137. self._test_equal(b, a)
  138. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
  139. b = np.array([np.inf, 4., 6.5])
  140. self._test_equal(a, b)
  141. self._test_equal(b, a)
  142. def test_subclass_that_overrides_eq(self):
  143. # While we cannot guarantee testing functions will always work for
  144. # subclasses, the tests should ideally rely only on subclasses having
  145. # comparison operators, not on them being able to store booleans
  146. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  147. class MyArray(np.ndarray):
  148. def __eq__(self, other):
  149. return bool(np.equal(self, other).all())
  150. def __ne__(self, other):
  151. return not self == other
  152. a = np.array([1., 2.]).view(MyArray)
  153. b = np.array([2., 3.]).view(MyArray)
  154. assert_(type(a == a), bool)
  155. assert_(a == a)
  156. assert_(a != b)
  157. self._test_equal(a, a)
  158. self._test_not_equal(a, b)
  159. self._test_not_equal(b, a)
  160. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  161. 'Max absolute difference among violations: 1.\n'
  162. 'Max relative difference among violations: 0.5')
  163. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  164. self._test_equal(a, b)
  165. c = np.array([0., 2.9]).view(MyArray)
  166. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  167. 'Max absolute difference among violations: 2.\n'
  168. 'Max relative difference among violations: inf')
  169. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  170. self._test_equal(b, c)
  171. def test_subclass_that_does_not_implement_npall(self):
  172. class MyArray(np.ndarray):
  173. def __array_function__(self, *args, **kwargs):
  174. return NotImplemented
  175. a = np.array([1., 2.]).view(MyArray)
  176. b = np.array([2., 3.]).view(MyArray)
  177. with assert_raises(TypeError):
  178. np.all(a)
  179. self._test_equal(a, a)
  180. self._test_not_equal(a, b)
  181. self._test_not_equal(b, a)
  182. def test_suppress_overflow_warnings(self):
  183. # Based on issue #18992
  184. with pytest.raises(AssertionError):
  185. with np.errstate(all="raise"):
  186. np.testing.assert_array_equal(
  187. np.array([1, 2, 3], np.float32),
  188. np.array([1, 1e-40, 3], np.float32))
  189. def test_array_vs_scalar_is_equal(self):
  190. """Test comparing an array with a scalar when all values are equal."""
  191. a = np.array([1., 1., 1.])
  192. b = 1.
  193. self._test_equal(a, b)
  194. def test_array_vs_array_not_equal(self):
  195. """Test comparing an array with a scalar when not all values equal."""
  196. a = np.array([34986, 545676, 439655, 563766])
  197. b = np.array([34986, 545676, 439655, 0])
  198. expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
  199. 'Max absolute difference among violations: 563766\n'
  200. 'Max relative difference among violations: inf')
  201. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  202. self._assert_func(a, b)
  203. a = np.array([34986, 545676, 439655.2, 563766])
  204. expected_msg = ('Mismatched elements: 2 / 4 (50%)\n'
  205. 'Max absolute difference among violations: '
  206. '563766.\n'
  207. 'Max relative difference among violations: '
  208. '4.54902139e-07')
  209. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  210. self._assert_func(a, b)
  211. def test_array_vs_scalar_strict(self):
  212. """Test comparing an array with a scalar with strict option."""
  213. a = np.array([1., 1., 1.])
  214. b = 1.
  215. with pytest.raises(AssertionError):
  216. self._assert_func(a, b, strict=True)
  217. def test_array_vs_array_strict(self):
  218. """Test comparing two arrays with strict option."""
  219. a = np.array([1., 1., 1.])
  220. b = np.array([1., 1., 1.])
  221. self._assert_func(a, b, strict=True)
  222. def test_array_vs_float_array_strict(self):
  223. """Test comparing two arrays with strict option."""
  224. a = np.array([1, 1, 1])
  225. b = np.array([1., 1., 1.])
  226. with pytest.raises(AssertionError):
  227. self._assert_func(a, b, strict=True)
  228. class TestBuildErrorMessage:
  229. def test_build_err_msg_defaults(self):
  230. x = np.array([1.00001, 2.00002, 3.00003])
  231. y = np.array([1.00002, 2.00003, 3.00004])
  232. err_msg = 'There is a mismatch'
  233. a = build_err_msg([x, y], err_msg)
  234. b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
  235. '1.00001, 2.00002, 3.00003])\n DESIRED: array([1.00002, '
  236. '2.00003, 3.00004])')
  237. assert_equal(a, b)
  238. def test_build_err_msg_no_verbose(self):
  239. x = np.array([1.00001, 2.00002, 3.00003])
  240. y = np.array([1.00002, 2.00003, 3.00004])
  241. err_msg = 'There is a mismatch'
  242. a = build_err_msg([x, y], err_msg, verbose=False)
  243. b = '\nItems are not equal: There is a mismatch'
  244. assert_equal(a, b)
  245. def test_build_err_msg_custom_names(self):
  246. x = np.array([1.00001, 2.00002, 3.00003])
  247. y = np.array([1.00002, 2.00003, 3.00004])
  248. err_msg = 'There is a mismatch'
  249. a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
  250. b = ('\nItems are not equal: There is a mismatch\n FOO: array(['
  251. '1.00001, 2.00002, 3.00003])\n BAR: array([1.00002, 2.00003, '
  252. '3.00004])')
  253. assert_equal(a, b)
  254. def test_build_err_msg_custom_precision(self):
  255. x = np.array([1.000000001, 2.00002, 3.00003])
  256. y = np.array([1.000000002, 2.00003, 3.00004])
  257. err_msg = 'There is a mismatch'
  258. a = build_err_msg([x, y], err_msg, precision=10)
  259. b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
  260. '1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array(['
  261. '1.000000002, 2.00003 , 3.00004 ])')
  262. assert_equal(a, b)
  263. class TestEqual(TestArrayEqual):
  264. def setup_method(self):
  265. self._assert_func = assert_equal
  266. def test_nan_items(self):
  267. self._assert_func(np.nan, np.nan)
  268. self._assert_func([np.nan], [np.nan])
  269. self._test_not_equal(np.nan, [np.nan])
  270. self._test_not_equal(np.nan, 1)
  271. def test_inf_items(self):
  272. self._assert_func(np.inf, np.inf)
  273. self._assert_func([np.inf], [np.inf])
  274. self._test_not_equal(np.inf, [np.inf])
  275. def test_datetime(self):
  276. self._test_equal(
  277. np.datetime64("2017-01-01", "s"),
  278. np.datetime64("2017-01-01", "s")
  279. )
  280. self._test_equal(
  281. np.datetime64("2017-01-01", "s"),
  282. np.datetime64("2017-01-01", "m")
  283. )
  284. # gh-10081
  285. self._test_not_equal(
  286. np.datetime64("2017-01-01", "s"),
  287. np.datetime64("2017-01-02", "s")
  288. )
  289. self._test_not_equal(
  290. np.datetime64("2017-01-01", "s"),
  291. np.datetime64("2017-01-02", "m")
  292. )
  293. def test_nat_items(self):
  294. # not a datetime
  295. nadt_no_unit = np.datetime64("NaT")
  296. nadt_s = np.datetime64("NaT", "s")
  297. nadt_d = np.datetime64("NaT", "ns")
  298. # not a timedelta
  299. natd_no_unit = np.timedelta64("NaT")
  300. natd_s = np.timedelta64("NaT", "s")
  301. natd_d = np.timedelta64("NaT", "ns")
  302. dts = [nadt_no_unit, nadt_s, nadt_d]
  303. tds = [natd_no_unit, natd_s, natd_d]
  304. for a, b in itertools.product(dts, dts):
  305. self._assert_func(a, b)
  306. self._assert_func([a], [b])
  307. self._test_not_equal([a], b)
  308. for a, b in itertools.product(tds, tds):
  309. self._assert_func(a, b)
  310. self._assert_func([a], [b])
  311. self._test_not_equal([a], b)
  312. for a, b in itertools.product(tds, dts):
  313. self._test_not_equal(a, b)
  314. self._test_not_equal(a, [b])
  315. self._test_not_equal([a], [b])
  316. self._test_not_equal([a], np.datetime64("2017-01-01", "s"))
  317. self._test_not_equal([b], np.datetime64("2017-01-01", "s"))
  318. self._test_not_equal([a], np.timedelta64(123, "s"))
  319. self._test_not_equal([b], np.timedelta64(123, "s"))
  320. def test_non_numeric(self):
  321. self._assert_func('ab', 'ab')
  322. self._test_not_equal('ab', 'abb')
  323. def test_complex_item(self):
  324. self._assert_func(complex(1, 2), complex(1, 2))
  325. self._assert_func(complex(1, np.nan), complex(1, np.nan))
  326. self._test_not_equal(complex(1, np.nan), complex(1, 2))
  327. self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
  328. self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
  329. def test_negative_zero(self):
  330. self._test_not_equal(ncu.PZERO, ncu.NZERO)
  331. def test_complex(self):
  332. x = np.array([complex(1, 2), complex(1, np.nan)])
  333. y = np.array([complex(1, 2), complex(1, 2)])
  334. self._assert_func(x, x)
  335. self._test_not_equal(x, y)
  336. def test_object(self):
  337. # gh-12942
  338. import datetime
  339. a = np.array([datetime.datetime(2000, 1, 1),
  340. datetime.datetime(2000, 1, 2)])
  341. self._test_not_equal(a, a[::-1])
  342. class TestArrayAlmostEqual(_GenericTest):
  343. def setup_method(self):
  344. self._assert_func = assert_array_almost_equal
  345. def test_closeness(self):
  346. # Note that in the course of time we ended up with
  347. # `abs(x - y) < 1.5 * 10**(-decimal)`
  348. # instead of the previously documented
  349. # `abs(x - y) < 0.5 * 10**(-decimal)`
  350. # so this check serves to preserve the wrongness.
  351. # test scalars
  352. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  353. 'Max absolute difference among violations: 1.5\n'
  354. 'Max relative difference among violations: inf')
  355. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  356. self._assert_func(1.5, 0.0, decimal=0)
  357. # test arrays
  358. self._assert_func([1.499999], [0.0], decimal=0)
  359. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  360. 'Max absolute difference among violations: 1.5\n'
  361. 'Max relative difference among violations: inf')
  362. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  363. self._assert_func([1.5], [0.0], decimal=0)
  364. a = [1.4999999, 0.00003]
  365. b = [1.49999991, 0]
  366. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  367. 'Max absolute difference among violations: 3.e-05\n'
  368. 'Max relative difference among violations: inf')
  369. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  370. self._assert_func(a, b, decimal=7)
  371. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  372. 'Max absolute difference among violations: 3.e-05\n'
  373. 'Max relative difference among violations: 1.')
  374. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  375. self._assert_func(b, a, decimal=7)
  376. def test_simple(self):
  377. x = np.array([1234.2222])
  378. y = np.array([1234.2223])
  379. self._assert_func(x, y, decimal=3)
  380. self._assert_func(x, y, decimal=4)
  381. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  382. 'Max absolute difference among violations: '
  383. '1.e-04\n'
  384. 'Max relative difference among violations: '
  385. '8.10226812e-08')
  386. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  387. self._assert_func(x, y, decimal=5)
  388. def test_array_vs_scalar(self):
  389. a = [5498.42354, 849.54345, 0.00]
  390. b = 5498.42354
  391. expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
  392. 'Max absolute difference among violations: '
  393. '5498.42354\n'
  394. 'Max relative difference among violations: 1.')
  395. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  396. self._assert_func(a, b, decimal=9)
  397. expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
  398. 'Max absolute difference among violations: '
  399. '5498.42354\n'
  400. 'Max relative difference among violations: 5.4722099')
  401. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  402. self._assert_func(b, a, decimal=9)
  403. a = [5498.42354, 0.00]
  404. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  405. 'Max absolute difference among violations: '
  406. '5498.42354\n'
  407. 'Max relative difference among violations: inf')
  408. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  409. self._assert_func(b, a, decimal=7)
  410. b = 0
  411. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  412. 'Max absolute difference among violations: '
  413. '5498.42354\n'
  414. 'Max relative difference among violations: inf')
  415. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  416. self._assert_func(a, b, decimal=7)
  417. def test_nan(self):
  418. anan = np.array([np.nan])
  419. aone = np.array([1])
  420. ainf = np.array([np.inf])
  421. self._assert_func(anan, anan)
  422. assert_raises(AssertionError,
  423. lambda: self._assert_func(anan, aone))
  424. assert_raises(AssertionError,
  425. lambda: self._assert_func(anan, ainf))
  426. assert_raises(AssertionError,
  427. lambda: self._assert_func(ainf, anan))
  428. def test_inf(self):
  429. a = np.array([[1., 2.], [3., 4.]])
  430. b = a.copy()
  431. a[0, 0] = np.inf
  432. assert_raises(AssertionError,
  433. lambda: self._assert_func(a, b))
  434. b[0, 0] = -np.inf
  435. assert_raises(AssertionError,
  436. lambda: self._assert_func(a, b))
  437. def test_subclass(self):
  438. a = np.array([[1., 2.], [3., 4.]])
  439. b = np.ma.masked_array([[1., 2.], [0., 4.]],
  440. [[False, False], [True, False]])
  441. self._assert_func(a, b)
  442. self._assert_func(b, a)
  443. self._assert_func(b, b)
  444. # Test fully masked as well (see gh-11123).
  445. a = np.ma.MaskedArray(3.5, mask=True)
  446. b = np.array([3., 4., 6.5])
  447. self._test_equal(a, b)
  448. self._test_equal(b, a)
  449. a = np.ma.masked
  450. b = np.array([3., 4., 6.5])
  451. self._test_equal(a, b)
  452. self._test_equal(b, a)
  453. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
  454. b = np.array([1., 2., 3.])
  455. self._test_equal(a, b)
  456. self._test_equal(b, a)
  457. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
  458. b = np.array(1.)
  459. self._test_equal(a, b)
  460. self._test_equal(b, a)
  461. def test_subclass_2(self):
  462. # While we cannot guarantee testing functions will always work for
  463. # subclasses, the tests should ideally rely only on subclasses having
  464. # comparison operators, not on them being able to store booleans
  465. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  466. class MyArray(np.ndarray):
  467. def __eq__(self, other):
  468. return super().__eq__(other).view(np.ndarray)
  469. def __lt__(self, other):
  470. return super().__lt__(other).view(np.ndarray)
  471. def all(self, *args, **kwargs):
  472. return all(self)
  473. a = np.array([1., 2.]).view(MyArray)
  474. self._assert_func(a, a)
  475. z = np.array([True, True]).view(MyArray)
  476. all(z)
  477. b = np.array([1., 202]).view(MyArray)
  478. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  479. 'Max absolute difference among violations: 200.\n'
  480. 'Max relative difference among violations: 0.99009')
  481. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  482. self._assert_func(a, b)
  483. def test_subclass_that_cannot_be_bool(self):
  484. # While we cannot guarantee testing functions will always work for
  485. # subclasses, the tests should ideally rely only on subclasses having
  486. # comparison operators, not on them being able to store booleans
  487. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  488. class MyArray(np.ndarray):
  489. def __eq__(self, other):
  490. return super().__eq__(other).view(np.ndarray)
  491. def __lt__(self, other):
  492. return super().__lt__(other).view(np.ndarray)
  493. def all(self, *args, **kwargs):
  494. raise NotImplementedError
  495. a = np.array([1., 2.]).view(MyArray)
  496. self._assert_func(a, a)
  497. class TestAlmostEqual(_GenericTest):
  498. def setup_method(self):
  499. self._assert_func = assert_almost_equal
  500. def test_closeness(self):
  501. # Note that in the course of time we ended up with
  502. # `abs(x - y) < 1.5 * 10**(-decimal)`
  503. # instead of the previously documented
  504. # `abs(x - y) < 0.5 * 10**(-decimal)`
  505. # so this check serves to preserve the wrongness.
  506. # test scalars
  507. self._assert_func(1.499999, 0.0, decimal=0)
  508. assert_raises(AssertionError,
  509. lambda: self._assert_func(1.5, 0.0, decimal=0))
  510. # test arrays
  511. self._assert_func([1.499999], [0.0], decimal=0)
  512. assert_raises(AssertionError,
  513. lambda: self._assert_func([1.5], [0.0], decimal=0))
  514. def test_nan_item(self):
  515. self._assert_func(np.nan, np.nan)
  516. assert_raises(AssertionError,
  517. lambda: self._assert_func(np.nan, 1))
  518. assert_raises(AssertionError,
  519. lambda: self._assert_func(np.nan, np.inf))
  520. assert_raises(AssertionError,
  521. lambda: self._assert_func(np.inf, np.nan))
  522. def test_inf_item(self):
  523. self._assert_func(np.inf, np.inf)
  524. self._assert_func(-np.inf, -np.inf)
  525. assert_raises(AssertionError,
  526. lambda: self._assert_func(np.inf, 1))
  527. assert_raises(AssertionError,
  528. lambda: self._assert_func(-np.inf, np.inf))
  529. def test_simple_item(self):
  530. self._test_not_equal(1, 2)
  531. def test_complex_item(self):
  532. self._assert_func(complex(1, 2), complex(1, 2))
  533. self._assert_func(complex(1, np.nan), complex(1, np.nan))
  534. self._assert_func(complex(np.inf, np.nan), complex(np.inf, np.nan))
  535. self._test_not_equal(complex(1, np.nan), complex(1, 2))
  536. self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
  537. self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
  538. def test_complex(self):
  539. x = np.array([complex(1, 2), complex(1, np.nan)])
  540. z = np.array([complex(1, 2), complex(np.nan, 1)])
  541. y = np.array([complex(1, 2), complex(1, 2)])
  542. self._assert_func(x, x)
  543. self._test_not_equal(x, y)
  544. self._test_not_equal(x, z)
  545. def test_error_message(self):
  546. """Check the message is formatted correctly for the decimal value.
  547. Also check the message when input includes inf or nan (gh12200)"""
  548. x = np.array([1.00000000001, 2.00000000002, 3.00003])
  549. y = np.array([1.00000000002, 2.00000000003, 3.00004])
  550. # Test with a different amount of decimal digits
  551. expected_msg = ('Mismatched elements: 3 / 3 (100%)\n'
  552. 'Max absolute difference among violations: 1.e-05\n'
  553. 'Max relative difference among violations: '
  554. '3.33328889e-06\n'
  555. ' ACTUAL: array([1.00000000001, '
  556. '2.00000000002, '
  557. '3.00003 ])\n'
  558. ' DESIRED: array([1.00000000002, 2.00000000003, '
  559. '3.00004 ])')
  560. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  561. self._assert_func(x, y, decimal=12)
  562. # With the default value of decimal digits, only the 3rd element
  563. # differs. Note that we only check for the formatting of the arrays
  564. # themselves.
  565. expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
  566. 'Max absolute difference among violations: 1.e-05\n'
  567. 'Max relative difference among violations: '
  568. '3.33328889e-06\n'
  569. ' ACTUAL: array([1. , 2. , 3.00003])\n'
  570. ' DESIRED: array([1. , 2. , 3.00004])')
  571. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  572. self._assert_func(x, y)
  573. # Check the error message when input includes inf
  574. x = np.array([np.inf, 0])
  575. y = np.array([np.inf, 1])
  576. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  577. 'Max absolute difference among violations: 1.\n'
  578. 'Max relative difference among violations: 1.\n'
  579. ' ACTUAL: array([inf, 0.])\n'
  580. ' DESIRED: array([inf, 1.])')
  581. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  582. self._assert_func(x, y)
  583. # Check the error message when dividing by zero
  584. x = np.array([1, 2])
  585. y = np.array([0, 0])
  586. expected_msg = ('Mismatched elements: 2 / 2 (100%)\n'
  587. 'Max absolute difference among violations: 2\n'
  588. 'Max relative difference among violations: inf')
  589. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  590. self._assert_func(x, y)
  591. def test_error_message_2(self):
  592. """Check the message is formatted correctly """
  593. """when either x or y is a scalar."""
  594. x = 2
  595. y = np.ones(20)
  596. expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
  597. 'Max absolute difference among violations: 1.\n'
  598. 'Max relative difference among violations: 1.')
  599. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  600. self._assert_func(x, y)
  601. y = 2
  602. x = np.ones(20)
  603. expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
  604. 'Max absolute difference among violations: 1.\n'
  605. 'Max relative difference among violations: 0.5')
  606. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  607. self._assert_func(x, y)
  608. def test_subclass_that_cannot_be_bool(self):
  609. # While we cannot guarantee testing functions will always work for
  610. # subclasses, the tests should ideally rely only on subclasses having
  611. # comparison operators, not on them being able to store booleans
  612. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  613. class MyArray(np.ndarray):
  614. def __eq__(self, other):
  615. return super().__eq__(other).view(np.ndarray)
  616. def __lt__(self, other):
  617. return super().__lt__(other).view(np.ndarray)
  618. def all(self, *args, **kwargs):
  619. raise NotImplementedError
  620. a = np.array([1., 2.]).view(MyArray)
  621. self._assert_func(a, a)
  622. class TestApproxEqual:
  623. def setup_method(self):
  624. self._assert_func = assert_approx_equal
  625. def test_simple_0d_arrays(self):
  626. x = np.array(1234.22)
  627. y = np.array(1234.23)
  628. self._assert_func(x, y, significant=5)
  629. self._assert_func(x, y, significant=6)
  630. assert_raises(AssertionError,
  631. lambda: self._assert_func(x, y, significant=7))
  632. def test_simple_items(self):
  633. x = 1234.22
  634. y = 1234.23
  635. self._assert_func(x, y, significant=4)
  636. self._assert_func(x, y, significant=5)
  637. self._assert_func(x, y, significant=6)
  638. assert_raises(AssertionError,
  639. lambda: self._assert_func(x, y, significant=7))
  640. def test_nan_array(self):
  641. anan = np.array(np.nan)
  642. aone = np.array(1)
  643. ainf = np.array(np.inf)
  644. self._assert_func(anan, anan)
  645. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  646. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  647. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  648. def test_nan_items(self):
  649. anan = np.array(np.nan)
  650. aone = np.array(1)
  651. ainf = np.array(np.inf)
  652. self._assert_func(anan, anan)
  653. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  654. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  655. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  656. class TestArrayAssertLess:
  657. def setup_method(self):
  658. self._assert_func = assert_array_less
  659. def test_simple_arrays(self):
  660. x = np.array([1.1, 2.2])
  661. y = np.array([1.2, 2.3])
  662. self._assert_func(x, y)
  663. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  664. y = np.array([1.0, 2.3])
  665. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  666. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  667. a = np.array([1, 3, 6, 20])
  668. b = np.array([2, 4, 6, 8])
  669. expected_msg = ('Mismatched elements: 2 / 4 (50%)\n'
  670. 'Max absolute difference among violations: 12\n'
  671. 'Max relative difference among violations: 1.5')
  672. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  673. self._assert_func(a, b)
  674. def test_rank2(self):
  675. x = np.array([[1.1, 2.2], [3.3, 4.4]])
  676. y = np.array([[1.2, 2.3], [3.4, 4.5]])
  677. self._assert_func(x, y)
  678. expected_msg = ('Mismatched elements: 4 / 4 (100%)\n'
  679. 'Max absolute difference among violations: 0.1\n'
  680. 'Max relative difference among violations: 0.09090909')
  681. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  682. self._assert_func(y, x)
  683. y = np.array([[1.0, 2.3], [3.4, 4.5]])
  684. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  685. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  686. def test_rank3(self):
  687. x = np.ones(shape=(2, 2, 2))
  688. y = np.ones(shape=(2, 2, 2))+1
  689. self._assert_func(x, y)
  690. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  691. y[0, 0, 0] = 0
  692. expected_msg = ('Mismatched elements: 1 / 8 (12.5%)\n'
  693. 'Max absolute difference among violations: 1.\n'
  694. 'Max relative difference among violations: inf')
  695. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  696. self._assert_func(x, y)
  697. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  698. def test_simple_items(self):
  699. x = 1.1
  700. y = 2.2
  701. self._assert_func(x, y)
  702. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  703. 'Max absolute difference among violations: 1.1\n'
  704. 'Max relative difference among violations: 1.')
  705. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  706. self._assert_func(y, x)
  707. y = np.array([2.2, 3.3])
  708. self._assert_func(x, y)
  709. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  710. y = np.array([1.0, 3.3])
  711. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  712. def test_simple_items_and_array(self):
  713. x = np.array([[621.345454, 390.5436, 43.54657, 626.4535],
  714. [54.54, 627.3399, 13., 405.5435],
  715. [543.545, 8.34, 91.543, 333.3]])
  716. y = 627.34
  717. self._assert_func(x, y)
  718. y = 8.339999
  719. self._assert_func(y, x)
  720. x = np.array([[3.4536, 2390.5436, 435.54657, 324525.4535],
  721. [5449.54, 999090.54, 130303.54, 405.5435],
  722. [543.545, 8.34, 91.543, 999090.53999]])
  723. y = 999090.54
  724. expected_msg = ('Mismatched elements: 1 / 12 (8.33%)\n'
  725. 'Max absolute difference among violations: 0.\n'
  726. 'Max relative difference among violations: 0.')
  727. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  728. self._assert_func(x, y)
  729. expected_msg = ('Mismatched elements: 12 / 12 (100%)\n'
  730. 'Max absolute difference among violations: '
  731. '999087.0864\n'
  732. 'Max relative difference among violations: '
  733. '289288.5934676')
  734. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  735. self._assert_func(y, x)
  736. def test_zeroes(self):
  737. x = np.array([546456., 0, 15.455])
  738. y = np.array(87654.)
  739. expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
  740. 'Max absolute difference among violations: 458802.\n'
  741. 'Max relative difference among violations: 5.23423917')
  742. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  743. self._assert_func(x, y)
  744. expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
  745. 'Max absolute difference among violations: 87654.\n'
  746. 'Max relative difference among violations: '
  747. '5670.5626011')
  748. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  749. self._assert_func(y, x)
  750. y = 0
  751. expected_msg = ('Mismatched elements: 3 / 3 (100%)\n'
  752. 'Max absolute difference among violations: 546456.\n'
  753. 'Max relative difference among violations: inf')
  754. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  755. self._assert_func(x, y)
  756. expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
  757. 'Max absolute difference among violations: 0.\n'
  758. 'Max relative difference among violations: inf')
  759. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  760. self._assert_func(y, x)
  761. def test_nan_noncompare(self):
  762. anan = np.array(np.nan)
  763. aone = np.array(1)
  764. ainf = np.array(np.inf)
  765. self._assert_func(anan, anan)
  766. assert_raises(AssertionError, lambda: self._assert_func(aone, anan))
  767. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  768. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  769. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  770. def test_nan_noncompare_array(self):
  771. x = np.array([1.1, 2.2, 3.3])
  772. anan = np.array(np.nan)
  773. assert_raises(AssertionError, lambda: self._assert_func(x, anan))
  774. assert_raises(AssertionError, lambda: self._assert_func(anan, x))
  775. x = np.array([1.1, 2.2, np.nan])
  776. assert_raises(AssertionError, lambda: self._assert_func(x, anan))
  777. assert_raises(AssertionError, lambda: self._assert_func(anan, x))
  778. y = np.array([1.0, 2.0, np.nan])
  779. self._assert_func(y, x)
  780. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  781. def test_inf_compare(self):
  782. aone = np.array(1)
  783. ainf = np.array(np.inf)
  784. self._assert_func(aone, ainf)
  785. self._assert_func(-ainf, aone)
  786. self._assert_func(-ainf, ainf)
  787. assert_raises(AssertionError, lambda: self._assert_func(ainf, aone))
  788. assert_raises(AssertionError, lambda: self._assert_func(aone, -ainf))
  789. assert_raises(AssertionError, lambda: self._assert_func(ainf, ainf))
  790. assert_raises(AssertionError, lambda: self._assert_func(ainf, -ainf))
  791. assert_raises(AssertionError, lambda: self._assert_func(-ainf, -ainf))
  792. def test_inf_compare_array(self):
  793. x = np.array([1.1, 2.2, np.inf])
  794. ainf = np.array(np.inf)
  795. assert_raises(AssertionError, lambda: self._assert_func(x, ainf))
  796. assert_raises(AssertionError, lambda: self._assert_func(ainf, x))
  797. assert_raises(AssertionError, lambda: self._assert_func(x, -ainf))
  798. assert_raises(AssertionError, lambda: self._assert_func(-x, -ainf))
  799. assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x))
  800. self._assert_func(-ainf, x)
  801. def test_strict(self):
  802. """Test the behavior of the `strict` option."""
  803. x = np.zeros(3)
  804. y = np.ones(())
  805. self._assert_func(x, y)
  806. with pytest.raises(AssertionError):
  807. self._assert_func(x, y, strict=True)
  808. y = np.broadcast_to(y, x.shape)
  809. self._assert_func(x, y)
  810. with pytest.raises(AssertionError):
  811. self._assert_func(x, y.astype(np.float32), strict=True)
  812. class TestWarns:
  813. def test_warn(self):
  814. def f():
  815. warnings.warn("yo")
  816. return 3
  817. before_filters = sys.modules['warnings'].filters[:]
  818. assert_equal(assert_warns(UserWarning, f), 3)
  819. after_filters = sys.modules['warnings'].filters
  820. assert_raises(AssertionError, assert_no_warnings, f)
  821. assert_equal(assert_no_warnings(lambda x: x, 1), 1)
  822. # Check that the warnings state is unchanged
  823. assert_equal(before_filters, after_filters,
  824. "assert_warns does not preserver warnings state")
  825. def test_context_manager(self):
  826. before_filters = sys.modules['warnings'].filters[:]
  827. with assert_warns(UserWarning):
  828. warnings.warn("yo")
  829. after_filters = sys.modules['warnings'].filters
  830. def no_warnings():
  831. with assert_no_warnings():
  832. warnings.warn("yo")
  833. assert_raises(AssertionError, no_warnings)
  834. assert_equal(before_filters, after_filters,
  835. "assert_warns does not preserver warnings state")
  836. def test_args(self):
  837. def f(a=0, b=1):
  838. warnings.warn("yo")
  839. return a + b
  840. assert assert_warns(UserWarning, f, b=20) == 20
  841. with pytest.raises(RuntimeError) as exc:
  842. # assert_warns cannot do regexp matching, use pytest.warns
  843. with assert_warns(UserWarning, match="A"):
  844. warnings.warn("B", UserWarning)
  845. assert "assert_warns" in str(exc)
  846. assert "pytest.warns" in str(exc)
  847. with pytest.raises(RuntimeError) as exc:
  848. # assert_warns cannot do regexp matching, use pytest.warns
  849. with assert_warns(UserWarning, wrong="A"):
  850. warnings.warn("B", UserWarning)
  851. assert "assert_warns" in str(exc)
  852. assert "pytest.warns" not in str(exc)
  853. def test_warn_wrong_warning(self):
  854. def f():
  855. warnings.warn("yo", DeprecationWarning)
  856. failed = False
  857. with warnings.catch_warnings():
  858. warnings.simplefilter("error", DeprecationWarning)
  859. try:
  860. # Should raise a DeprecationWarning
  861. assert_warns(UserWarning, f)
  862. failed = True
  863. except DeprecationWarning:
  864. pass
  865. if failed:
  866. raise AssertionError("wrong warning caught by assert_warn")
  867. class TestAssertAllclose:
  868. def test_simple(self):
  869. x = 1e-3
  870. y = 1e-9
  871. assert_allclose(x, y, atol=1)
  872. assert_raises(AssertionError, assert_allclose, x, y)
  873. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  874. 'Max absolute difference among violations: 0.001\n'
  875. 'Max relative difference among violations: 999999.')
  876. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  877. assert_allclose(x, y)
  878. z = 0
  879. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  880. 'Max absolute difference among violations: 1.e-09\n'
  881. 'Max relative difference among violations: inf')
  882. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  883. assert_allclose(y, z)
  884. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  885. 'Max absolute difference among violations: 1.e-09\n'
  886. 'Max relative difference among violations: 1.')
  887. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  888. assert_allclose(z, y)
  889. a = np.array([x, y, x, y])
  890. b = np.array([x, y, x, x])
  891. assert_allclose(a, b, atol=1)
  892. assert_raises(AssertionError, assert_allclose, a, b)
  893. b[-1] = y * (1 + 1e-8)
  894. assert_allclose(a, b)
  895. assert_raises(AssertionError, assert_allclose, a, b, rtol=1e-9)
  896. assert_allclose(6, 10, rtol=0.5)
  897. assert_raises(AssertionError, assert_allclose, 10, 6, rtol=0.5)
  898. b = np.array([x, y, x, x])
  899. c = np.array([x, y, x, z])
  900. expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
  901. 'Max absolute difference among violations: 0.001\n'
  902. 'Max relative difference among violations: inf')
  903. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  904. assert_allclose(b, c)
  905. expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
  906. 'Max absolute difference among violations: 0.001\n'
  907. 'Max relative difference among violations: 1.')
  908. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  909. assert_allclose(c, b)
  910. def test_min_int(self):
  911. a = np.array([np.iinfo(np.int_).min], dtype=np.int_)
  912. # Should not raise:
  913. assert_allclose(a, a)
  914. def test_report_fail_percentage(self):
  915. a = np.array([1, 1, 1, 1])
  916. b = np.array([1, 1, 1, 2])
  917. expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
  918. 'Max absolute difference among violations: 1\n'
  919. 'Max relative difference among violations: 0.5')
  920. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  921. assert_allclose(a, b)
  922. def test_equal_nan(self):
  923. a = np.array([np.nan])
  924. b = np.array([np.nan])
  925. # Should not raise:
  926. assert_allclose(a, b, equal_nan=True)
  927. def test_not_equal_nan(self):
  928. a = np.array([np.nan])
  929. b = np.array([np.nan])
  930. assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
  931. def test_equal_nan_default(self):
  932. # Make sure equal_nan default behavior remains unchanged. (All
  933. # of these functions use assert_array_compare under the hood.)
  934. # None of these should raise.
  935. a = np.array([np.nan])
  936. b = np.array([np.nan])
  937. assert_array_equal(a, b)
  938. assert_array_almost_equal(a, b)
  939. assert_array_less(a, b)
  940. assert_allclose(a, b)
  941. def test_report_max_relative_error(self):
  942. a = np.array([0, 1])
  943. b = np.array([0, 2])
  944. expected_msg = 'Max relative difference among violations: 0.5'
  945. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  946. assert_allclose(a, b)
  947. def test_timedelta(self):
  948. # see gh-18286
  949. a = np.array([[1, 2, 3, "NaT"]], dtype="m8[ns]")
  950. assert_allclose(a, a)
  951. def test_error_message_unsigned(self):
  952. """Check the message is formatted correctly when overflow can occur
  953. (gh21768)"""
  954. # Ensure to test for potential overflow in the case of:
  955. # x - y
  956. # and
  957. # y - x
  958. x = np.asarray([0, 1, 8], dtype='uint8')
  959. y = np.asarray([4, 4, 4], dtype='uint8')
  960. expected_msg = 'Max absolute difference among violations: 4'
  961. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  962. assert_allclose(x, y, atol=3)
  963. def test_strict(self):
  964. """Test the behavior of the `strict` option."""
  965. x = np.ones(3)
  966. y = np.ones(())
  967. assert_allclose(x, y)
  968. with pytest.raises(AssertionError):
  969. assert_allclose(x, y, strict=True)
  970. assert_allclose(x, x)
  971. with pytest.raises(AssertionError):
  972. assert_allclose(x, x.astype(np.float32), strict=True)
  973. class TestArrayAlmostEqualNulp:
  974. def test_float64_pass(self):
  975. # The number of units of least precision
  976. # In this case, use a few places above the lowest level (ie nulp=1)
  977. nulp = 5
  978. x = np.linspace(-20, 20, 50, dtype=np.float64)
  979. x = 10**x
  980. x = np.r_[-x, x]
  981. # Addition
  982. eps = np.finfo(x.dtype).eps
  983. y = x + x*eps*nulp/2.
  984. assert_array_almost_equal_nulp(x, y, nulp)
  985. # Subtraction
  986. epsneg = np.finfo(x.dtype).epsneg
  987. y = x - x*epsneg*nulp/2.
  988. assert_array_almost_equal_nulp(x, y, nulp)
  989. def test_float64_fail(self):
  990. nulp = 5
  991. x = np.linspace(-20, 20, 50, dtype=np.float64)
  992. x = 10**x
  993. x = np.r_[-x, x]
  994. eps = np.finfo(x.dtype).eps
  995. y = x + x*eps*nulp*2.
  996. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  997. x, y, nulp)
  998. epsneg = np.finfo(x.dtype).epsneg
  999. y = x - x*epsneg*nulp*2.
  1000. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1001. x, y, nulp)
  1002. def test_float64_ignore_nan(self):
  1003. # Ignore ULP differences between various NAN's
  1004. # Note that MIPS may reverse quiet and signaling nans
  1005. # so we use the builtin version as a base.
  1006. offset = np.uint64(0xffffffff)
  1007. nan1_i64 = np.array(np.nan, dtype=np.float64).view(np.uint64)
  1008. nan2_i64 = nan1_i64 ^ offset # nan payload on MIPS is all ones.
  1009. nan1_f64 = nan1_i64.view(np.float64)
  1010. nan2_f64 = nan2_i64.view(np.float64)
  1011. assert_array_max_ulp(nan1_f64, nan2_f64, 0)
  1012. def test_float32_pass(self):
  1013. nulp = 5
  1014. x = np.linspace(-20, 20, 50, dtype=np.float32)
  1015. x = 10**x
  1016. x = np.r_[-x, x]
  1017. eps = np.finfo(x.dtype).eps
  1018. y = x + x*eps*nulp/2.
  1019. assert_array_almost_equal_nulp(x, y, nulp)
  1020. epsneg = np.finfo(x.dtype).epsneg
  1021. y = x - x*epsneg*nulp/2.
  1022. assert_array_almost_equal_nulp(x, y, nulp)
  1023. def test_float32_fail(self):
  1024. nulp = 5
  1025. x = np.linspace(-20, 20, 50, dtype=np.float32)
  1026. x = 10**x
  1027. x = np.r_[-x, x]
  1028. eps = np.finfo(x.dtype).eps
  1029. y = x + x*eps*nulp*2.
  1030. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1031. x, y, nulp)
  1032. epsneg = np.finfo(x.dtype).epsneg
  1033. y = x - x*epsneg*nulp*2.
  1034. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1035. x, y, nulp)
  1036. def test_float32_ignore_nan(self):
  1037. # Ignore ULP differences between various NAN's
  1038. # Note that MIPS may reverse quiet and signaling nans
  1039. # so we use the builtin version as a base.
  1040. offset = np.uint32(0xffff)
  1041. nan1_i32 = np.array(np.nan, dtype=np.float32).view(np.uint32)
  1042. nan2_i32 = nan1_i32 ^ offset # nan payload on MIPS is all ones.
  1043. nan1_f32 = nan1_i32.view(np.float32)
  1044. nan2_f32 = nan2_i32.view(np.float32)
  1045. assert_array_max_ulp(nan1_f32, nan2_f32, 0)
  1046. def test_float16_pass(self):
  1047. nulp = 5
  1048. x = np.linspace(-4, 4, 10, dtype=np.float16)
  1049. x = 10**x
  1050. x = np.r_[-x, x]
  1051. eps = np.finfo(x.dtype).eps
  1052. y = x + x*eps*nulp/2.
  1053. assert_array_almost_equal_nulp(x, y, nulp)
  1054. epsneg = np.finfo(x.dtype).epsneg
  1055. y = x - x*epsneg*nulp/2.
  1056. assert_array_almost_equal_nulp(x, y, nulp)
  1057. def test_float16_fail(self):
  1058. nulp = 5
  1059. x = np.linspace(-4, 4, 10, dtype=np.float16)
  1060. x = 10**x
  1061. x = np.r_[-x, x]
  1062. eps = np.finfo(x.dtype).eps
  1063. y = x + x*eps*nulp*2.
  1064. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1065. x, y, nulp)
  1066. epsneg = np.finfo(x.dtype).epsneg
  1067. y = x - x*epsneg*nulp*2.
  1068. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1069. x, y, nulp)
  1070. def test_float16_ignore_nan(self):
  1071. # Ignore ULP differences between various NAN's
  1072. # Note that MIPS may reverse quiet and signaling nans
  1073. # so we use the builtin version as a base.
  1074. offset = np.uint16(0xff)
  1075. nan1_i16 = np.array(np.nan, dtype=np.float16).view(np.uint16)
  1076. nan2_i16 = nan1_i16 ^ offset # nan payload on MIPS is all ones.
  1077. nan1_f16 = nan1_i16.view(np.float16)
  1078. nan2_f16 = nan2_i16.view(np.float16)
  1079. assert_array_max_ulp(nan1_f16, nan2_f16, 0)
  1080. def test_complex128_pass(self):
  1081. nulp = 5
  1082. x = np.linspace(-20, 20, 50, dtype=np.float64)
  1083. x = 10**x
  1084. x = np.r_[-x, x]
  1085. xi = x + x*1j
  1086. eps = np.finfo(x.dtype).eps
  1087. y = x + x*eps*nulp/2.
  1088. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  1089. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  1090. # The test condition needs to be at least a factor of sqrt(2) smaller
  1091. # because the real and imaginary parts both change
  1092. y = x + x*eps*nulp/4.
  1093. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  1094. epsneg = np.finfo(x.dtype).epsneg
  1095. y = x - x*epsneg*nulp/2.
  1096. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  1097. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  1098. y = x - x*epsneg*nulp/4.
  1099. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  1100. def test_complex128_fail(self):
  1101. nulp = 5
  1102. x = np.linspace(-20, 20, 50, dtype=np.float64)
  1103. x = 10**x
  1104. x = np.r_[-x, x]
  1105. xi = x + x*1j
  1106. eps = np.finfo(x.dtype).eps
  1107. y = x + x*eps*nulp*2.
  1108. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1109. xi, x + y*1j, nulp)
  1110. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1111. xi, y + x*1j, nulp)
  1112. # The test condition needs to be at least a factor of sqrt(2) smaller
  1113. # because the real and imaginary parts both change
  1114. y = x + x*eps*nulp
  1115. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1116. xi, y + y*1j, nulp)
  1117. epsneg = np.finfo(x.dtype).epsneg
  1118. y = x - x*epsneg*nulp*2.
  1119. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1120. xi, x + y*1j, nulp)
  1121. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1122. xi, y + x*1j, nulp)
  1123. y = x - x*epsneg*nulp
  1124. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1125. xi, y + y*1j, nulp)
  1126. def test_complex64_pass(self):
  1127. nulp = 5
  1128. x = np.linspace(-20, 20, 50, dtype=np.float32)
  1129. x = 10**x
  1130. x = np.r_[-x, x]
  1131. xi = x + x*1j
  1132. eps = np.finfo(x.dtype).eps
  1133. y = x + x*eps*nulp/2.
  1134. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  1135. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  1136. y = x + x*eps*nulp/4.
  1137. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  1138. epsneg = np.finfo(x.dtype).epsneg
  1139. y = x - x*epsneg*nulp/2.
  1140. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  1141. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  1142. y = x - x*epsneg*nulp/4.
  1143. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  1144. def test_complex64_fail(self):
  1145. nulp = 5
  1146. x = np.linspace(-20, 20, 50, dtype=np.float32)
  1147. x = 10**x
  1148. x = np.r_[-x, x]
  1149. xi = x + x*1j
  1150. eps = np.finfo(x.dtype).eps
  1151. y = x + x*eps*nulp*2.
  1152. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1153. xi, x + y*1j, nulp)
  1154. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1155. xi, y + x*1j, nulp)
  1156. y = x + x*eps*nulp
  1157. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1158. xi, y + y*1j, nulp)
  1159. epsneg = np.finfo(x.dtype).epsneg
  1160. y = x - x*epsneg*nulp*2.
  1161. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1162. xi, x + y*1j, nulp)
  1163. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1164. xi, y + x*1j, nulp)
  1165. y = x - x*epsneg*nulp
  1166. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1167. xi, y + y*1j, nulp)
  1168. class TestULP:
  1169. def test_equal(self):
  1170. x = np.random.randn(10)
  1171. assert_array_max_ulp(x, x, maxulp=0)
  1172. def test_single(self):
  1173. # Generate 1 + small deviation, check that adding eps gives a few UNL
  1174. x = np.ones(10).astype(np.float32)
  1175. x += 0.01 * np.random.randn(10).astype(np.float32)
  1176. eps = np.finfo(np.float32).eps
  1177. assert_array_max_ulp(x, x+eps, maxulp=20)
  1178. def test_double(self):
  1179. # Generate 1 + small deviation, check that adding eps gives a few UNL
  1180. x = np.ones(10).astype(np.float64)
  1181. x += 0.01 * np.random.randn(10).astype(np.float64)
  1182. eps = np.finfo(np.float64).eps
  1183. assert_array_max_ulp(x, x+eps, maxulp=200)
  1184. def test_inf(self):
  1185. for dt in [np.float32, np.float64]:
  1186. inf = np.array([np.inf]).astype(dt)
  1187. big = np.array([np.finfo(dt).max])
  1188. assert_array_max_ulp(inf, big, maxulp=200)
  1189. def test_nan(self):
  1190. # Test that nan is 'far' from small, tiny, inf, max and min
  1191. for dt in [np.float32, np.float64]:
  1192. if dt == np.float32:
  1193. maxulp = 1e6
  1194. else:
  1195. maxulp = 1e12
  1196. inf = np.array([np.inf]).astype(dt)
  1197. nan = np.array([np.nan]).astype(dt)
  1198. big = np.array([np.finfo(dt).max])
  1199. tiny = np.array([np.finfo(dt).tiny])
  1200. zero = np.array([0.0]).astype(dt)
  1201. nzero = np.array([-0.0]).astype(dt)
  1202. assert_raises(AssertionError,
  1203. lambda: assert_array_max_ulp(nan, inf,
  1204. maxulp=maxulp))
  1205. assert_raises(AssertionError,
  1206. lambda: assert_array_max_ulp(nan, big,
  1207. maxulp=maxulp))
  1208. assert_raises(AssertionError,
  1209. lambda: assert_array_max_ulp(nan, tiny,
  1210. maxulp=maxulp))
  1211. assert_raises(AssertionError,
  1212. lambda: assert_array_max_ulp(nan, zero,
  1213. maxulp=maxulp))
  1214. assert_raises(AssertionError,
  1215. lambda: assert_array_max_ulp(nan, nzero,
  1216. maxulp=maxulp))
  1217. class TestStringEqual:
  1218. def test_simple(self):
  1219. assert_string_equal("hello", "hello")
  1220. assert_string_equal("hello\nmultiline", "hello\nmultiline")
  1221. with pytest.raises(AssertionError) as exc_info:
  1222. assert_string_equal("foo\nbar", "hello\nbar")
  1223. msg = str(exc_info.value)
  1224. assert_equal(msg, "Differences in strings:\n- foo\n+ hello")
  1225. assert_raises(AssertionError,
  1226. lambda: assert_string_equal("foo", "hello"))
  1227. def test_regex(self):
  1228. assert_string_equal("a+*b", "a+*b")
  1229. assert_raises(AssertionError,
  1230. lambda: assert_string_equal("aaa", "a+b"))
  1231. def assert_warn_len_equal(mod, n_in_context):
  1232. try:
  1233. mod_warns = mod.__warningregistry__
  1234. except AttributeError:
  1235. # the lack of a __warningregistry__
  1236. # attribute means that no warning has
  1237. # occurred; this can be triggered in
  1238. # a parallel test scenario, while in
  1239. # a serial test scenario an initial
  1240. # warning (and therefore the attribute)
  1241. # are always created first
  1242. mod_warns = {}
  1243. num_warns = len(mod_warns)
  1244. if 'version' in mod_warns:
  1245. # Python 3 adds a 'version' entry to the registry,
  1246. # do not count it.
  1247. num_warns -= 1
  1248. assert_equal(num_warns, n_in_context)
  1249. def test_warn_len_equal_call_scenarios():
  1250. # assert_warn_len_equal is called under
  1251. # varying circumstances depending on serial
  1252. # vs. parallel test scenarios; this test
  1253. # simply aims to probe both code paths and
  1254. # check that no assertion is uncaught
  1255. # parallel scenario -- no warning issued yet
  1256. class mod:
  1257. pass
  1258. mod_inst = mod()
  1259. assert_warn_len_equal(mod=mod_inst,
  1260. n_in_context=0)
  1261. # serial test scenario -- the __warningregistry__
  1262. # attribute should be present
  1263. class mod:
  1264. def __init__(self):
  1265. self.__warningregistry__ = {'warning1': 1,
  1266. 'warning2': 2}
  1267. mod_inst = mod()
  1268. assert_warn_len_equal(mod=mod_inst,
  1269. n_in_context=2)
  1270. def _get_fresh_mod():
  1271. # Get this module, with warning registry empty
  1272. my_mod = sys.modules[__name__]
  1273. try:
  1274. my_mod.__warningregistry__.clear()
  1275. except AttributeError:
  1276. # will not have a __warningregistry__ unless warning has been
  1277. # raised in the module at some point
  1278. pass
  1279. return my_mod
  1280. def test_clear_and_catch_warnings():
  1281. # Initial state of module, no warnings
  1282. my_mod = _get_fresh_mod()
  1283. assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
  1284. with clear_and_catch_warnings(modules=[my_mod]):
  1285. warnings.simplefilter('ignore')
  1286. warnings.warn('Some warning')
  1287. assert_equal(my_mod.__warningregistry__, {})
  1288. # Without specified modules, don't clear warnings during context.
  1289. # catch_warnings doesn't make an entry for 'ignore'.
  1290. with clear_and_catch_warnings():
  1291. warnings.simplefilter('ignore')
  1292. warnings.warn('Some warning')
  1293. assert_warn_len_equal(my_mod, 0)
  1294. # Manually adding two warnings to the registry:
  1295. my_mod.__warningregistry__ = {'warning1': 1,
  1296. 'warning2': 2}
  1297. # Confirm that specifying module keeps old warning, does not add new
  1298. with clear_and_catch_warnings(modules=[my_mod]):
  1299. warnings.simplefilter('ignore')
  1300. warnings.warn('Another warning')
  1301. assert_warn_len_equal(my_mod, 2)
  1302. # Another warning, no module spec it clears up registry
  1303. with clear_and_catch_warnings():
  1304. warnings.simplefilter('ignore')
  1305. warnings.warn('Another warning')
  1306. assert_warn_len_equal(my_mod, 0)
  1307. def test_suppress_warnings_module():
  1308. # Initial state of module, no warnings
  1309. my_mod = _get_fresh_mod()
  1310. assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
  1311. def warn_other_module():
  1312. # Apply along axis is implemented in python; stacklevel=2 means
  1313. # we end up inside its module, not ours.
  1314. def warn(arr):
  1315. warnings.warn("Some warning 2", stacklevel=2)
  1316. return arr
  1317. np.apply_along_axis(warn, 0, [0])
  1318. # Test module based warning suppression:
  1319. assert_warn_len_equal(my_mod, 0)
  1320. with suppress_warnings() as sup:
  1321. sup.record(UserWarning)
  1322. # suppress warning from other module (may have .pyc ending),
  1323. # if apply_along_axis is moved, had to be changed.
  1324. sup.filter(module=np.lib._shape_base_impl)
  1325. warnings.warn("Some warning")
  1326. warn_other_module()
  1327. # Check that the suppression did test the file correctly (this module
  1328. # got filtered)
  1329. assert_equal(len(sup.log), 1)
  1330. assert_equal(sup.log[0].message.args[0], "Some warning")
  1331. assert_warn_len_equal(my_mod, 0)
  1332. sup = suppress_warnings()
  1333. # Will have to be changed if apply_along_axis is moved:
  1334. sup.filter(module=my_mod)
  1335. with sup:
  1336. warnings.warn('Some warning')
  1337. assert_warn_len_equal(my_mod, 0)
  1338. # And test repeat works:
  1339. sup.filter(module=my_mod)
  1340. with sup:
  1341. warnings.warn('Some warning')
  1342. assert_warn_len_equal(my_mod, 0)
  1343. # Without specified modules
  1344. with suppress_warnings():
  1345. warnings.simplefilter('ignore')
  1346. warnings.warn('Some warning')
  1347. assert_warn_len_equal(my_mod, 0)
  1348. def test_suppress_warnings_type():
  1349. # Initial state of module, no warnings
  1350. my_mod = _get_fresh_mod()
  1351. assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
  1352. # Test module based warning suppression:
  1353. with suppress_warnings() as sup:
  1354. sup.filter(UserWarning)
  1355. warnings.warn('Some warning')
  1356. assert_warn_len_equal(my_mod, 0)
  1357. sup = suppress_warnings()
  1358. sup.filter(UserWarning)
  1359. with sup:
  1360. warnings.warn('Some warning')
  1361. assert_warn_len_equal(my_mod, 0)
  1362. # And test repeat works:
  1363. sup.filter(module=my_mod)
  1364. with sup:
  1365. warnings.warn('Some warning')
  1366. assert_warn_len_equal(my_mod, 0)
  1367. # Without specified modules
  1368. with suppress_warnings():
  1369. warnings.simplefilter('ignore')
  1370. warnings.warn('Some warning')
  1371. assert_warn_len_equal(my_mod, 0)
  1372. def test_suppress_warnings_decorate_no_record():
  1373. sup = suppress_warnings()
  1374. sup.filter(UserWarning)
  1375. @sup
  1376. def warn(category):
  1377. warnings.warn('Some warning', category)
  1378. with warnings.catch_warnings(record=True) as w:
  1379. warnings.simplefilter("always")
  1380. warn(UserWarning) # should be suppressed
  1381. warn(RuntimeWarning)
  1382. assert_equal(len(w), 1)
  1383. def test_suppress_warnings_record():
  1384. sup = suppress_warnings()
  1385. log1 = sup.record()
  1386. with sup:
  1387. log2 = sup.record(message='Some other warning 2')
  1388. sup.filter(message='Some warning')
  1389. warnings.warn('Some warning')
  1390. warnings.warn('Some other warning')
  1391. warnings.warn('Some other warning 2')
  1392. assert_equal(len(sup.log), 2)
  1393. assert_equal(len(log1), 1)
  1394. assert_equal(len(log2), 1)
  1395. assert_equal(log2[0].message.args[0], 'Some other warning 2')
  1396. # Do it again, with the same context to see if some warnings survived:
  1397. with sup:
  1398. log2 = sup.record(message='Some other warning 2')
  1399. sup.filter(message='Some warning')
  1400. warnings.warn('Some warning')
  1401. warnings.warn('Some other warning')
  1402. warnings.warn('Some other warning 2')
  1403. assert_equal(len(sup.log), 2)
  1404. assert_equal(len(log1), 1)
  1405. assert_equal(len(log2), 1)
  1406. assert_equal(log2[0].message.args[0], 'Some other warning 2')
  1407. # Test nested:
  1408. with suppress_warnings() as sup:
  1409. sup.record()
  1410. with suppress_warnings() as sup2:
  1411. sup2.record(message='Some warning')
  1412. warnings.warn('Some warning')
  1413. warnings.warn('Some other warning')
  1414. assert_equal(len(sup2.log), 1)
  1415. assert_equal(len(sup.log), 1)
  1416. def test_suppress_warnings_forwarding():
  1417. def warn_other_module():
  1418. # Apply along axis is implemented in python; stacklevel=2 means
  1419. # we end up inside its module, not ours.
  1420. def warn(arr):
  1421. warnings.warn("Some warning", stacklevel=2)
  1422. return arr
  1423. np.apply_along_axis(warn, 0, [0])
  1424. with suppress_warnings() as sup:
  1425. sup.record()
  1426. with suppress_warnings("always"):
  1427. for i in range(2):
  1428. warnings.warn("Some warning")
  1429. assert_equal(len(sup.log), 2)
  1430. with suppress_warnings() as sup:
  1431. sup.record()
  1432. with suppress_warnings("location"):
  1433. for i in range(2):
  1434. warnings.warn("Some warning")
  1435. warnings.warn("Some warning")
  1436. assert_equal(len(sup.log), 2)
  1437. with suppress_warnings() as sup:
  1438. sup.record()
  1439. with suppress_warnings("module"):
  1440. for i in range(2):
  1441. warnings.warn("Some warning")
  1442. warnings.warn("Some warning")
  1443. warn_other_module()
  1444. assert_equal(len(sup.log), 2)
  1445. with suppress_warnings() as sup:
  1446. sup.record()
  1447. with suppress_warnings("once"):
  1448. for i in range(2):
  1449. warnings.warn("Some warning")
  1450. warnings.warn("Some other warning")
  1451. warn_other_module()
  1452. assert_equal(len(sup.log), 2)
  1453. def test_tempdir():
  1454. with tempdir() as tdir:
  1455. fpath = os.path.join(tdir, 'tmp')
  1456. with open(fpath, 'w'):
  1457. pass
  1458. assert_(not os.path.isdir(tdir))
  1459. raised = False
  1460. try:
  1461. with tempdir() as tdir:
  1462. raise ValueError
  1463. except ValueError:
  1464. raised = True
  1465. assert_(raised)
  1466. assert_(not os.path.isdir(tdir))
  1467. def test_temppath():
  1468. with temppath() as fpath:
  1469. with open(fpath, 'w'):
  1470. pass
  1471. assert_(not os.path.isfile(fpath))
  1472. raised = False
  1473. try:
  1474. with temppath() as fpath:
  1475. raise ValueError
  1476. except ValueError:
  1477. raised = True
  1478. assert_(raised)
  1479. assert_(not os.path.isfile(fpath))
  1480. class my_cacw(clear_and_catch_warnings):
  1481. class_modules = (sys.modules[__name__],)
  1482. def test_clear_and_catch_warnings_inherit():
  1483. # Test can subclass and add default modules
  1484. my_mod = _get_fresh_mod()
  1485. with my_cacw():
  1486. warnings.simplefilter('ignore')
  1487. warnings.warn('Some warning')
  1488. assert_equal(my_mod.__warningregistry__, {})
  1489. @pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
  1490. class TestAssertNoGcCycles:
  1491. """ Test assert_no_gc_cycles """
  1492. def test_passes(self):
  1493. def no_cycle():
  1494. b = []
  1495. b.append([])
  1496. return b
  1497. with assert_no_gc_cycles():
  1498. no_cycle()
  1499. assert_no_gc_cycles(no_cycle)
  1500. def test_asserts(self):
  1501. def make_cycle():
  1502. a = []
  1503. a.append(a)
  1504. a.append(a)
  1505. return a
  1506. with assert_raises(AssertionError):
  1507. with assert_no_gc_cycles():
  1508. make_cycle()
  1509. with assert_raises(AssertionError):
  1510. assert_no_gc_cycles(make_cycle)
  1511. @pytest.mark.slow
  1512. def test_fails(self):
  1513. """
  1514. Test that in cases where the garbage cannot be collected, we raise an
  1515. error, instead of hanging forever trying to clear it.
  1516. """
  1517. class ReferenceCycleInDel:
  1518. """
  1519. An object that not only contains a reference cycle, but creates new
  1520. cycles whenever it's garbage-collected and its __del__ runs
  1521. """
  1522. make_cycle = True
  1523. def __init__(self):
  1524. self.cycle = self
  1525. def __del__(self):
  1526. # break the current cycle so that `self` can be freed
  1527. self.cycle = None
  1528. if ReferenceCycleInDel.make_cycle:
  1529. # but create a new one so that the garbage collector has more
  1530. # work to do.
  1531. ReferenceCycleInDel()
  1532. try:
  1533. w = weakref.ref(ReferenceCycleInDel())
  1534. try:
  1535. with assert_raises(RuntimeError):
  1536. # this will be unable to get a baseline empty garbage
  1537. assert_no_gc_cycles(lambda: None)
  1538. except AssertionError:
  1539. # the above test is only necessary if the GC actually tried to free
  1540. # our object anyway, which python 2.7 does not.
  1541. if w() is not None:
  1542. pytest.skip("GC does not call __del__ on cyclic objects")
  1543. raise
  1544. finally:
  1545. # make sure that we stop creating reference cycles
  1546. ReferenceCycleInDel.make_cycle = False
  1547. @pytest.mark.parametrize('assert_func', [assert_array_equal,
  1548. assert_array_almost_equal])
  1549. def test_xy_rename(assert_func):
  1550. # Test that keywords `x` and `y` have been renamed to `actual` and
  1551. # `desired`, respectively. These tests and use of `_rename_parameter`
  1552. # decorator can be removed before the release of NumPy 2.2.0.
  1553. assert_func(1, 1)
  1554. assert_func(actual=1, desired=1)
  1555. assert_message = "Arrays are not..."
  1556. with pytest.raises(AssertionError, match=assert_message):
  1557. assert_func(1, 2)
  1558. with pytest.raises(AssertionError, match=assert_message):
  1559. assert_func(actual=1, desired=2)
  1560. dep_message = 'Use of keyword argument...'
  1561. with pytest.warns(DeprecationWarning, match=dep_message):
  1562. assert_func(x=1, desired=1)
  1563. with pytest.warns(DeprecationWarning, match=dep_message):
  1564. assert_func(1, y=1)
  1565. type_message = '...got multiple values for argument'
  1566. with (pytest.warns(DeprecationWarning, match=dep_message),
  1567. pytest.raises(TypeError, match=type_message)):
  1568. assert_func(1, x=1)
  1569. assert_func(1, 2, y=2)