test_utils.py 78 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123
  1. import itertools
  2. import os
  3. import re
  4. import sys
  5. import warnings
  6. import weakref
  7. import pytest
  8. import numpy as np
  9. import numpy._core._multiarray_umath as ncu
  10. from numpy.testing import (
  11. HAS_REFCOUNT,
  12. assert_,
  13. assert_allclose,
  14. assert_almost_equal,
  15. assert_approx_equal,
  16. assert_array_almost_equal,
  17. assert_array_almost_equal_nulp,
  18. assert_array_equal,
  19. assert_array_less,
  20. assert_array_max_ulp,
  21. assert_equal,
  22. assert_no_gc_cycles,
  23. assert_no_warnings,
  24. assert_raises,
  25. assert_string_equal,
  26. assert_warns,
  27. build_err_msg,
  28. clear_and_catch_warnings,
  29. suppress_warnings,
  30. tempdir,
  31. temppath,
  32. )
  33. class _GenericTest:
  34. def _assert_func(self, *args, **kwargs):
  35. pass
  36. def _test_equal(self, a, b):
  37. self._assert_func(a, b)
  38. def _test_not_equal(self, a, b):
  39. with assert_raises(AssertionError):
  40. self._assert_func(a, b)
  41. def test_array_rank1_eq(self):
  42. """Test two equal array of rank 1 are found equal."""
  43. a = np.array([1, 2])
  44. b = np.array([1, 2])
  45. self._test_equal(a, b)
  46. def test_array_rank1_noteq(self):
  47. """Test two different array of rank 1 are found not equal."""
  48. a = np.array([1, 2])
  49. b = np.array([2, 2])
  50. self._test_not_equal(a, b)
  51. def test_array_rank2_eq(self):
  52. """Test two equal array of rank 2 are found equal."""
  53. a = np.array([[1, 2], [3, 4]])
  54. b = np.array([[1, 2], [3, 4]])
  55. self._test_equal(a, b)
  56. def test_array_diffshape(self):
  57. """Test two arrays with different shapes are found not equal."""
  58. a = np.array([1, 2])
  59. b = np.array([[1, 2], [1, 2]])
  60. self._test_not_equal(a, b)
  61. def test_objarray(self):
  62. """Test object arrays."""
  63. a = np.array([1, 1], dtype=object)
  64. self._test_equal(a, 1)
  65. def test_array_likes(self):
  66. self._test_equal([1, 2, 3], (1, 2, 3))
  67. class TestArrayEqual(_GenericTest):
  68. def _assert_func(self, *args, **kwargs):
  69. assert_array_equal(*args, **kwargs)
  70. def test_generic_rank1(self):
  71. """Test rank 1 array for all dtypes."""
  72. def foo(t):
  73. a = np.empty(2, t)
  74. a.fill(1)
  75. b = a.copy()
  76. c = a.copy()
  77. c.fill(0)
  78. self._test_equal(a, b)
  79. self._test_not_equal(c, b)
  80. # Test numeric types and object
  81. for t in '?bhilqpBHILQPfdgFDG':
  82. foo(t)
  83. # Test strings
  84. for t in ['S1', 'U1']:
  85. foo(t)
  86. def test_0_ndim_array(self):
  87. x = np.array(473963742225900817127911193656584771)
  88. y = np.array(18535119325151578301457182298393896)
  89. with pytest.raises(AssertionError) as exc_info:
  90. self._assert_func(x, y)
  91. msg = str(exc_info.value)
  92. assert_('Mismatched elements: 1 / 1 (100%)\n'
  93. in msg)
  94. y = x
  95. self._assert_func(x, y)
  96. x = np.array(4395065348745.5643764887869876)
  97. y = np.array(0)
  98. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  99. 'Max absolute difference among violations: '
  100. '4.39506535e+12\n'
  101. 'Max relative difference among violations: inf\n')
  102. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  103. self._assert_func(x, y)
  104. x = y
  105. self._assert_func(x, y)
  106. def test_generic_rank3(self):
  107. """Test rank 3 array for all dtypes."""
  108. def foo(t):
  109. a = np.empty((4, 2, 3), t)
  110. a.fill(1)
  111. b = a.copy()
  112. c = a.copy()
  113. c.fill(0)
  114. self._test_equal(a, b)
  115. self._test_not_equal(c, b)
  116. # Test numeric types and object
  117. for t in '?bhilqpBHILQPfdgFDG':
  118. foo(t)
  119. # Test strings
  120. for t in ['S1', 'U1']:
  121. foo(t)
  122. def test_nan_array(self):
  123. """Test arrays with nan values in them."""
  124. a = np.array([1, 2, np.nan])
  125. b = np.array([1, 2, np.nan])
  126. self._test_equal(a, b)
  127. c = np.array([1, 2, 3])
  128. self._test_not_equal(c, b)
  129. def test_string_arrays(self):
  130. """Test two arrays with different shapes are found not equal."""
  131. a = np.array(['floupi', 'floupa'])
  132. b = np.array(['floupi', 'floupa'])
  133. self._test_equal(a, b)
  134. c = np.array(['floupipi', 'floupa'])
  135. self._test_not_equal(c, b)
  136. def test_recarrays(self):
  137. """Test record arrays."""
  138. a = np.empty(2, [('floupi', float), ('floupa', float)])
  139. a['floupi'] = [1, 2]
  140. a['floupa'] = [1, 2]
  141. b = a.copy()
  142. self._test_equal(a, b)
  143. c = np.empty(2, [('floupipi', float),
  144. ('floupi', float), ('floupa', float)])
  145. c['floupipi'] = a['floupi'].copy()
  146. c['floupa'] = a['floupa'].copy()
  147. with pytest.raises(TypeError):
  148. self._test_not_equal(c, b)
  149. def test_masked_nan_inf(self):
  150. # Regression test for gh-11121
  151. a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
  152. b = np.array([3., np.nan, 6.5])
  153. self._test_equal(a, b)
  154. self._test_equal(b, a)
  155. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
  156. b = np.array([np.inf, 4., 6.5])
  157. self._test_equal(a, b)
  158. self._test_equal(b, a)
  159. # Also provides test cases for gh-11121
  160. def test_masked_scalar(self):
  161. # Test masked scalar vs. plain/masked scalar
  162. for a_val, b_val, b_masked in itertools.product(
  163. [3., np.nan, np.inf],
  164. [3., 4., np.nan, np.inf, -np.inf],
  165. [False, True],
  166. ):
  167. a = np.ma.MaskedArray(a_val, mask=True)
  168. b = np.ma.MaskedArray(b_val, mask=True) if b_masked else np.array(b_val)
  169. self._test_equal(a, b)
  170. self._test_equal(b, a)
  171. # Test masked scalar vs. plain array
  172. for a_val, b_val in itertools.product(
  173. [3., np.nan, -np.inf],
  174. itertools.product([3., 4., np.nan, np.inf, -np.inf], repeat=2),
  175. ):
  176. a = np.ma.MaskedArray(a_val, mask=True)
  177. b = np.array(b_val)
  178. self._test_equal(a, b)
  179. self._test_equal(b, a)
  180. # Test masked scalar vs. masked array
  181. for a_val, b_val, b_mask in itertools.product(
  182. [3., np.nan, np.inf],
  183. itertools.product([3., 4., np.nan, np.inf, -np.inf], repeat=2),
  184. itertools.product([False, True], repeat=2),
  185. ):
  186. a = np.ma.MaskedArray(a_val, mask=True)
  187. b = np.ma.MaskedArray(b_val, mask=b_mask)
  188. self._test_equal(a, b)
  189. self._test_equal(b, a)
  190. def test_subclass_that_overrides_eq(self):
  191. # While we cannot guarantee testing functions will always work for
  192. # subclasses, the tests should ideally rely only on subclasses having
  193. # comparison operators, not on them being able to store booleans
  194. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  195. class MyArray(np.ndarray):
  196. def __eq__(self, other):
  197. return bool(np.equal(self, other).all())
  198. def __ne__(self, other):
  199. return not self == other
  200. a = np.array([1., 2.]).view(MyArray)
  201. b = np.array([2., 3.]).view(MyArray)
  202. assert_(type(a == a), bool)
  203. assert_(a == a)
  204. assert_(a != b)
  205. self._test_equal(a, a)
  206. self._test_not_equal(a, b)
  207. self._test_not_equal(b, a)
  208. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  209. 'Max absolute difference among violations: 1.\n'
  210. 'Max relative difference among violations: 0.5')
  211. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  212. self._test_equal(a, b)
  213. c = np.array([0., 2.9]).view(MyArray)
  214. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  215. 'Max absolute difference among violations: 2.\n'
  216. 'Max relative difference among violations: inf')
  217. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  218. self._test_equal(b, c)
  219. def test_subclass_that_does_not_implement_npall(self):
  220. class MyArray(np.ndarray):
  221. def __array_function__(self, *args, **kwargs):
  222. return NotImplemented
  223. a = np.array([1., 2.]).view(MyArray)
  224. b = np.array([2., 3.]).view(MyArray)
  225. with assert_raises(TypeError):
  226. np.all(a)
  227. self._test_equal(a, a)
  228. self._test_not_equal(a, b)
  229. self._test_not_equal(b, a)
  230. def test_suppress_overflow_warnings(self):
  231. # Based on issue #18992
  232. with pytest.raises(AssertionError):
  233. with np.errstate(all="raise"):
  234. np.testing.assert_array_equal(
  235. np.array([1, 2, 3], np.float32),
  236. np.array([1, 1e-40, 3], np.float32))
  237. def test_array_vs_scalar_is_equal(self):
  238. """Test comparing an array with a scalar when all values are equal."""
  239. a = np.array([1., 1., 1.])
  240. b = 1.
  241. self._test_equal(a, b)
  242. def test_array_vs_array_not_equal(self):
  243. """Test comparing an array with a scalar when not all values equal."""
  244. a = np.array([34986, 545676, 439655, 563766])
  245. b = np.array([34986, 545676, 439655, 0])
  246. expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
  247. 'Mismatch at index:\n'
  248. ' [3]: 563766 (ACTUAL), 0 (DESIRED)\n'
  249. 'Max absolute difference among violations: 563766\n'
  250. 'Max relative difference among violations: inf')
  251. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  252. self._assert_func(a, b)
  253. a = np.array([34986, 545676, 439655.2, 563766])
  254. expected_msg = ('Mismatched elements: 2 / 4 (50%)\n'
  255. 'Mismatch at indices:\n'
  256. ' [2]: 439655.2 (ACTUAL), 439655 (DESIRED)\n'
  257. ' [3]: 563766.0 (ACTUAL), 0 (DESIRED)\n'
  258. 'Max absolute difference among violations: '
  259. '563766.\n'
  260. 'Max relative difference among violations: '
  261. '4.54902139e-07')
  262. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  263. self._assert_func(a, b)
  264. def test_array_vs_scalar_strict(self):
  265. """Test comparing an array with a scalar with strict option."""
  266. a = np.array([1., 1., 1.])
  267. b = 1.
  268. with pytest.raises(AssertionError):
  269. self._assert_func(a, b, strict=True)
  270. def test_array_vs_array_strict(self):
  271. """Test comparing two arrays with strict option."""
  272. a = np.array([1., 1., 1.])
  273. b = np.array([1., 1., 1.])
  274. self._assert_func(a, b, strict=True)
  275. def test_array_vs_float_array_strict(self):
  276. """Test comparing two arrays with strict option."""
  277. a = np.array([1, 1, 1])
  278. b = np.array([1., 1., 1.])
  279. with pytest.raises(AssertionError):
  280. self._assert_func(a, b, strict=True)
  281. class TestBuildErrorMessage:
  282. def test_build_err_msg_defaults(self):
  283. x = np.array([1.00001, 2.00002, 3.00003])
  284. y = np.array([1.00002, 2.00003, 3.00004])
  285. err_msg = 'There is a mismatch'
  286. a = build_err_msg([x, y], err_msg)
  287. b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
  288. '1.00001, 2.00002, 3.00003])\n DESIRED: array([1.00002, '
  289. '2.00003, 3.00004])')
  290. assert_equal(a, b)
  291. def test_build_err_msg_no_verbose(self):
  292. x = np.array([1.00001, 2.00002, 3.00003])
  293. y = np.array([1.00002, 2.00003, 3.00004])
  294. err_msg = 'There is a mismatch'
  295. a = build_err_msg([x, y], err_msg, verbose=False)
  296. b = '\nItems are not equal: There is a mismatch'
  297. assert_equal(a, b)
  298. def test_build_err_msg_custom_names(self):
  299. x = np.array([1.00001, 2.00002, 3.00003])
  300. y = np.array([1.00002, 2.00003, 3.00004])
  301. err_msg = 'There is a mismatch'
  302. a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
  303. b = ('\nItems are not equal: There is a mismatch\n FOO: array(['
  304. '1.00001, 2.00002, 3.00003])\n BAR: array([1.00002, 2.00003, '
  305. '3.00004])')
  306. assert_equal(a, b)
  307. def test_build_err_msg_custom_precision(self):
  308. x = np.array([1.000000001, 2.00002, 3.00003])
  309. y = np.array([1.000000002, 2.00003, 3.00004])
  310. err_msg = 'There is a mismatch'
  311. a = build_err_msg([x, y], err_msg, precision=10)
  312. b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
  313. '1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array(['
  314. '1.000000002, 2.00003 , 3.00004 ])')
  315. assert_equal(a, b)
  316. class TestEqual(TestArrayEqual):
  317. def _assert_func(self, *args, **kwargs):
  318. assert_equal(*args, **kwargs)
  319. def test_nan_items(self):
  320. self._assert_func(np.nan, np.nan)
  321. self._assert_func([np.nan], [np.nan])
  322. self._test_not_equal(np.nan, [np.nan])
  323. self._test_not_equal(np.nan, 1)
  324. def test_inf_items(self):
  325. self._assert_func(np.inf, np.inf)
  326. self._assert_func([np.inf], [np.inf])
  327. self._test_not_equal(np.inf, [np.inf])
  328. def test_datetime(self):
  329. self._test_equal(
  330. np.datetime64("2017-01-01", "s"),
  331. np.datetime64("2017-01-01", "s")
  332. )
  333. self._test_equal(
  334. np.datetime64("2017-01-01", "s"),
  335. np.datetime64("2017-01-01", "m")
  336. )
  337. # gh-10081
  338. self._test_not_equal(
  339. np.datetime64("2017-01-01", "s"),
  340. np.datetime64("2017-01-02", "s")
  341. )
  342. self._test_not_equal(
  343. np.datetime64("2017-01-01", "s"),
  344. np.datetime64("2017-01-02", "m")
  345. )
  346. def test_nat_items(self):
  347. # not a datetime
  348. nadt_no_unit = np.datetime64("NaT")
  349. nadt_s = np.datetime64("NaT", "s")
  350. nadt_d = np.datetime64("NaT", "ns")
  351. # not a timedelta
  352. natd_no_unit = np.timedelta64("NaT")
  353. natd_s = np.timedelta64("NaT", "s")
  354. natd_d = np.timedelta64("NaT", "ns")
  355. dts = [nadt_no_unit, nadt_s, nadt_d]
  356. tds = [natd_no_unit, natd_s, natd_d]
  357. for a, b in itertools.product(dts, dts):
  358. self._assert_func(a, b)
  359. self._assert_func([a], [b])
  360. self._test_not_equal([a], b)
  361. for a, b in itertools.product(tds, tds):
  362. self._assert_func(a, b)
  363. self._assert_func([a], [b])
  364. self._test_not_equal([a], b)
  365. for a, b in itertools.product(tds, dts):
  366. self._test_not_equal(a, b)
  367. self._test_not_equal(a, [b])
  368. self._test_not_equal([a], [b])
  369. self._test_not_equal([a], np.datetime64("2017-01-01", "s"))
  370. self._test_not_equal([b], np.datetime64("2017-01-01", "s"))
  371. self._test_not_equal([a], np.timedelta64(123, "s"))
  372. self._test_not_equal([b], np.timedelta64(123, "s"))
  373. def test_non_numeric(self):
  374. self._assert_func('ab', 'ab')
  375. self._test_not_equal('ab', 'abb')
  376. def test_complex_item(self):
  377. self._assert_func(complex(1, 2), complex(1, 2))
  378. self._assert_func(complex(1, np.nan), complex(1, np.nan))
  379. self._test_not_equal(complex(1, np.nan), complex(1, 2))
  380. self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
  381. self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
  382. def test_negative_zero(self):
  383. self._test_not_equal(ncu.PZERO, ncu.NZERO)
  384. def test_complex(self):
  385. x = np.array([complex(1, 2), complex(1, np.nan)])
  386. y = np.array([complex(1, 2), complex(1, 2)])
  387. self._assert_func(x, x)
  388. self._test_not_equal(x, y)
  389. def test_object(self):
  390. # gh-12942
  391. import datetime
  392. a = np.array([datetime.datetime(2000, 1, 1),
  393. datetime.datetime(2000, 1, 2)])
  394. self._test_not_equal(a, a[::-1])
  395. class TestArrayAlmostEqual(_GenericTest):
  396. def _assert_func(self, *args, **kwargs):
  397. assert_array_almost_equal(*args, **kwargs)
  398. def test_closeness(self):
  399. # Note that in the course of time we ended up with
  400. # `abs(x - y) < 1.5 * 10**(-decimal)`
  401. # instead of the previously documented
  402. # `abs(x - y) < 0.5 * 10**(-decimal)`
  403. # so this check serves to preserve the wrongness.
  404. # test scalars
  405. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  406. 'Max absolute difference among violations: 1.5\n'
  407. 'Max relative difference among violations: inf')
  408. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  409. self._assert_func(1.5, 0.0, decimal=0)
  410. # test arrays
  411. self._assert_func([1.499999], [0.0], decimal=0)
  412. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  413. 'Mismatch at index:\n'
  414. ' [0]: 1.5 (ACTUAL), 0.0 (DESIRED)\n'
  415. 'Max absolute difference among violations: 1.5\n'
  416. 'Max relative difference among violations: inf')
  417. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  418. self._assert_func([1.5], [0.0], decimal=0)
  419. a = [1.4999999, 0.00003]
  420. b = [1.49999991, 0]
  421. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  422. 'Mismatch at index:\n'
  423. ' [1]: 3e-05 (ACTUAL), 0.0 (DESIRED)\n'
  424. 'Max absolute difference among violations: 3.e-05\n'
  425. 'Max relative difference among violations: inf')
  426. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  427. self._assert_func(a, b, decimal=7)
  428. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  429. 'Mismatch at index:\n'
  430. ' [1]: 0.0 (ACTUAL), 3e-05 (DESIRED)\n'
  431. 'Max absolute difference among violations: 3.e-05\n'
  432. 'Max relative difference among violations: 1.')
  433. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  434. self._assert_func(b, a, decimal=7)
  435. def test_simple(self):
  436. x = np.array([1234.2222])
  437. y = np.array([1234.2223])
  438. self._assert_func(x, y, decimal=3)
  439. self._assert_func(x, y, decimal=4)
  440. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  441. 'Mismatch at index:\n'
  442. ' [0]: 1234.2222 (ACTUAL), 1234.2223 (DESIRED)\n'
  443. 'Max absolute difference among violations: '
  444. '1.e-04\n'
  445. 'Max relative difference among violations: '
  446. '8.10226812e-08')
  447. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  448. self._assert_func(x, y, decimal=5)
  449. def test_array_vs_scalar(self):
  450. a = [5498.42354, 849.54345, 0.00]
  451. b = 5498.42354
  452. expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
  453. 'Mismatch at indices:\n'
  454. ' [1]: 849.54345 (ACTUAL), 5498.42354 (DESIRED)\n'
  455. ' [2]: 0.0 (ACTUAL), 5498.42354 (DESIRED)\n'
  456. 'Max absolute difference among violations: '
  457. '5498.42354\n'
  458. 'Max relative difference among violations: 1.')
  459. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  460. self._assert_func(a, b, decimal=9)
  461. expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
  462. 'Mismatch at indices:\n'
  463. ' [1]: 5498.42354 (ACTUAL), 849.54345 (DESIRED)\n'
  464. ' [2]: 5498.42354 (ACTUAL), 0.0 (DESIRED)\n'
  465. 'Max absolute difference among violations: '
  466. '5498.42354\n'
  467. 'Max relative difference among violations: 5.4722099')
  468. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  469. self._assert_func(b, a, decimal=9)
  470. a = [5498.42354, 0.00]
  471. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  472. 'Mismatch at index:\n'
  473. ' [1]: 5498.42354 (ACTUAL), 0.0 (DESIRED)\n'
  474. 'Max absolute difference among violations: '
  475. '5498.42354\n'
  476. 'Max relative difference among violations: inf')
  477. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  478. self._assert_func(b, a, decimal=7)
  479. b = 0
  480. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  481. 'Mismatch at index:\n'
  482. ' [0]: 5498.42354 (ACTUAL), 0 (DESIRED)\n'
  483. 'Max absolute difference among violations: '
  484. '5498.42354\n'
  485. 'Max relative difference among violations: inf')
  486. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  487. self._assert_func(a, b, decimal=7)
  488. def test_nan(self):
  489. anan = np.array([np.nan])
  490. aone = np.array([1])
  491. ainf = np.array([np.inf])
  492. self._assert_func(anan, anan)
  493. assert_raises(AssertionError,
  494. lambda: self._assert_func(anan, aone))
  495. assert_raises(AssertionError,
  496. lambda: self._assert_func(anan, ainf))
  497. assert_raises(AssertionError,
  498. lambda: self._assert_func(ainf, anan))
  499. def test_inf(self):
  500. a = np.array([[1., 2.], [3., 4.]])
  501. b = a.copy()
  502. a[0, 0] = np.inf
  503. assert_raises(AssertionError,
  504. lambda: self._assert_func(a, b))
  505. b[0, 0] = -np.inf
  506. assert_raises(AssertionError,
  507. lambda: self._assert_func(a, b))
  508. def test_complex_inf(self):
  509. a = np.array([np.inf + 1.j, 2. + 1.j, 3. + 1.j])
  510. b = a.copy()
  511. self._assert_func(a, b)
  512. b[1] = 3. + 1.j
  513. expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
  514. 'Mismatch at index:\n'
  515. ' [1]: (2+1j) (ACTUAL), (3+1j) (DESIRED)\n'
  516. 'Max absolute difference among violations: 1.\n')
  517. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  518. self._assert_func(a, b)
  519. def test_subclass(self):
  520. a = np.array([[1., 2.], [3., 4.]])
  521. b = np.ma.masked_array([[1., 2.], [0., 4.]],
  522. [[False, False], [True, False]])
  523. self._assert_func(a, b)
  524. self._assert_func(b, a)
  525. self._assert_func(b, b)
  526. # Test fully masked as well (see gh-11123).
  527. a = np.ma.MaskedArray(3.5, mask=True)
  528. b = np.array([3., 4., 6.5])
  529. self._test_equal(a, b)
  530. self._test_equal(b, a)
  531. a = np.ma.masked
  532. b = np.array([3., 4., 6.5])
  533. self._test_equal(a, b)
  534. self._test_equal(b, a)
  535. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
  536. b = np.array([1., 2., 3.])
  537. self._test_equal(a, b)
  538. self._test_equal(b, a)
  539. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
  540. b = np.array(1.)
  541. self._test_equal(a, b)
  542. self._test_equal(b, a)
  543. def test_subclass_2(self):
  544. # While we cannot guarantee testing functions will always work for
  545. # subclasses, the tests should ideally rely only on subclasses having
  546. # comparison operators, not on them being able to store booleans
  547. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  548. class MyArray(np.ndarray):
  549. def __eq__(self, other):
  550. return super().__eq__(other).view(np.ndarray)
  551. def __lt__(self, other):
  552. return super().__lt__(other).view(np.ndarray)
  553. def all(self, *args, **kwargs):
  554. return all(self)
  555. a = np.array([1., 2.]).view(MyArray)
  556. self._assert_func(a, a)
  557. z = np.array([True, True]).view(MyArray)
  558. all(z)
  559. b = np.array([1., 202]).view(MyArray)
  560. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  561. 'Mismatch at index:\n'
  562. ' [1]: 2.0 (ACTUAL), 202.0 (DESIRED)\n'
  563. 'Max absolute difference among violations: 200.\n'
  564. 'Max relative difference among violations: 0.99009')
  565. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  566. self._assert_func(a, b)
  567. def test_subclass_that_cannot_be_bool(self):
  568. # While we cannot guarantee testing functions will always work for
  569. # subclasses, the tests should ideally rely only on subclasses having
  570. # comparison operators, not on them being able to store booleans
  571. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  572. class MyArray(np.ndarray):
  573. def __eq__(self, other):
  574. return super().__eq__(other).view(np.ndarray)
  575. def __lt__(self, other):
  576. return super().__lt__(other).view(np.ndarray)
  577. def all(self, *args, **kwargs):
  578. raise NotImplementedError
  579. a = np.array([1., 2.]).view(MyArray)
  580. self._assert_func(a, a)
  581. class TestAlmostEqual(_GenericTest):
  582. def _assert_func(self, *args, **kwargs):
  583. assert_almost_equal(*args, **kwargs)
  584. def test_closeness(self):
  585. # Note that in the course of time we ended up with
  586. # `abs(x - y) < 1.5 * 10**(-decimal)`
  587. # instead of the previously documented
  588. # `abs(x - y) < 0.5 * 10**(-decimal)`
  589. # so this check serves to preserve the wrongness.
  590. # test scalars
  591. self._assert_func(1.499999, 0.0, decimal=0)
  592. assert_raises(AssertionError,
  593. lambda: self._assert_func(1.5, 0.0, decimal=0))
  594. # test arrays
  595. self._assert_func([1.499999], [0.0], decimal=0)
  596. assert_raises(AssertionError,
  597. lambda: self._assert_func([1.5], [0.0], decimal=0))
  598. def test_nan_item(self):
  599. self._assert_func(np.nan, np.nan)
  600. assert_raises(AssertionError,
  601. lambda: self._assert_func(np.nan, 1))
  602. assert_raises(AssertionError,
  603. lambda: self._assert_func(np.nan, np.inf))
  604. assert_raises(AssertionError,
  605. lambda: self._assert_func(np.inf, np.nan))
  606. def test_inf_item(self):
  607. self._assert_func(np.inf, np.inf)
  608. self._assert_func(-np.inf, -np.inf)
  609. assert_raises(AssertionError,
  610. lambda: self._assert_func(np.inf, 1))
  611. assert_raises(AssertionError,
  612. lambda: self._assert_func(-np.inf, np.inf))
  613. def test_simple_item(self):
  614. self._test_not_equal(1, 2)
  615. def test_complex_item(self):
  616. self._assert_func(complex(1, 2), complex(1, 2))
  617. self._assert_func(complex(1, np.nan), complex(1, np.nan))
  618. self._assert_func(complex(np.inf, np.nan), complex(np.inf, np.nan))
  619. self._test_not_equal(complex(1, np.nan), complex(1, 2))
  620. self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
  621. self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
  622. def test_complex(self):
  623. x = np.array([complex(1, 2), complex(1, np.nan)])
  624. z = np.array([complex(1, 2), complex(np.nan, 1)])
  625. y = np.array([complex(1, 2), complex(1, 2)])
  626. self._assert_func(x, x)
  627. self._test_not_equal(x, y)
  628. self._test_not_equal(x, z)
  629. def test_error_message(self):
  630. """Check the message is formatted correctly for the decimal value.
  631. Also check the message when input includes inf or nan (gh12200)"""
  632. x = np.array([1.00000000001, 2.00000000002, 3.00003])
  633. y = np.array([1.00000000002, 2.00000000003, 3.00004])
  634. # Test with a different amount of decimal digits
  635. expected_msg = ('Mismatched elements: 3 / 3 (100%)\n'
  636. 'Mismatch at indices:\n'
  637. ' [0]: 1.00000000001 (ACTUAL), 1.00000000002 (DESIRED)\n'
  638. ' [1]: 2.00000000002 (ACTUAL), 2.00000000003 (DESIRED)\n'
  639. ' [2]: 3.00003 (ACTUAL), 3.00004 (DESIRED)\n'
  640. 'Max absolute difference among violations: 1.e-05\n'
  641. 'Max relative difference among violations: '
  642. '3.33328889e-06\n'
  643. ' ACTUAL: array([1.00000000001, '
  644. '2.00000000002, '
  645. '3.00003 ])\n'
  646. ' DESIRED: array([1.00000000002, 2.00000000003, '
  647. '3.00004 ])')
  648. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  649. self._assert_func(x, y, decimal=12)
  650. # With the default value of decimal digits, only the 3rd element
  651. # differs. Note that we only check for the formatting of the arrays
  652. # themselves.
  653. expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
  654. 'Mismatch at index:\n'
  655. ' [2]: 3.00003 (ACTUAL), 3.00004 (DESIRED)\n'
  656. 'Max absolute difference among violations: 1.e-05\n'
  657. 'Max relative difference among violations: '
  658. '3.33328889e-06\n'
  659. ' ACTUAL: array([1. , 2. , 3.00003])\n'
  660. ' DESIRED: array([1. , 2. , 3.00004])')
  661. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  662. self._assert_func(x, y)
  663. # Check the error message when input includes inf
  664. x = np.array([np.inf, 0])
  665. y = np.array([np.inf, 1])
  666. expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
  667. 'Mismatch at index:\n'
  668. ' [1]: 0.0 (ACTUAL), 1.0 (DESIRED)\n'
  669. 'Max absolute difference among violations: 1.\n'
  670. 'Max relative difference among violations: 1.\n'
  671. ' ACTUAL: array([inf, 0.])\n'
  672. ' DESIRED: array([inf, 1.])')
  673. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  674. self._assert_func(x, y)
  675. # Check the error message when dividing by zero
  676. x = np.array([1, 2])
  677. y = np.array([0, 0])
  678. expected_msg = ('Mismatched elements: 2 / 2 (100%)\n'
  679. 'Mismatch at indices:\n'
  680. ' [0]: 1 (ACTUAL), 0 (DESIRED)\n'
  681. ' [1]: 2 (ACTUAL), 0 (DESIRED)\n'
  682. 'Max absolute difference among violations: 2\n'
  683. 'Max relative difference among violations: inf')
  684. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  685. self._assert_func(x, y)
  686. def test_error_message_2(self):
  687. """Check the message is formatted correctly """
  688. """when either x or y is a scalar."""
  689. x = 2
  690. y = np.ones(20)
  691. expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
  692. 'First 5 mismatches are at indices:\n'
  693. ' [0]: 2 (ACTUAL), 1.0 (DESIRED)\n'
  694. ' [1]: 2 (ACTUAL), 1.0 (DESIRED)\n'
  695. ' [2]: 2 (ACTUAL), 1.0 (DESIRED)\n'
  696. ' [3]: 2 (ACTUAL), 1.0 (DESIRED)\n'
  697. ' [4]: 2 (ACTUAL), 1.0 (DESIRED)\n'
  698. 'Max absolute difference among violations: 1.\n'
  699. 'Max relative difference among violations: 1.')
  700. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  701. self._assert_func(x, y)
  702. y = 2
  703. x = np.ones(20)
  704. expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
  705. 'First 5 mismatches are at indices:\n'
  706. ' [0]: 1.0 (ACTUAL), 2 (DESIRED)\n'
  707. ' [1]: 1.0 (ACTUAL), 2 (DESIRED)\n'
  708. ' [2]: 1.0 (ACTUAL), 2 (DESIRED)\n'
  709. ' [3]: 1.0 (ACTUAL), 2 (DESIRED)\n'
  710. ' [4]: 1.0 (ACTUAL), 2 (DESIRED)\n'
  711. 'Max absolute difference among violations: 1.\n'
  712. 'Max relative difference among violations: 0.5')
  713. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  714. self._assert_func(x, y)
  715. def test_subclass_that_cannot_be_bool(self):
  716. # While we cannot guarantee testing functions will always work for
  717. # subclasses, the tests should ideally rely only on subclasses having
  718. # comparison operators, not on them being able to store booleans
  719. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  720. class MyArray(np.ndarray):
  721. def __eq__(self, other):
  722. return super().__eq__(other).view(np.ndarray)
  723. def __lt__(self, other):
  724. return super().__lt__(other).view(np.ndarray)
  725. def all(self, *args, **kwargs):
  726. raise NotImplementedError
  727. a = np.array([1., 2.]).view(MyArray)
  728. self._assert_func(a, a)
  729. class TestApproxEqual:
  730. def _assert_func(self, *args, **kwargs):
  731. assert_approx_equal(*args, **kwargs)
  732. def test_simple_0d_arrays(self):
  733. x = np.array(1234.22)
  734. y = np.array(1234.23)
  735. self._assert_func(x, y, significant=5)
  736. self._assert_func(x, y, significant=6)
  737. assert_raises(AssertionError,
  738. lambda: self._assert_func(x, y, significant=7))
  739. def test_simple_items(self):
  740. x = 1234.22
  741. y = 1234.23
  742. self._assert_func(x, y, significant=4)
  743. self._assert_func(x, y, significant=5)
  744. self._assert_func(x, y, significant=6)
  745. assert_raises(AssertionError,
  746. lambda: self._assert_func(x, y, significant=7))
  747. def test_nan_array(self):
  748. anan = np.array(np.nan)
  749. aone = np.array(1)
  750. ainf = np.array(np.inf)
  751. self._assert_func(anan, anan)
  752. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  753. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  754. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  755. def test_nan_items(self):
  756. anan = np.array(np.nan)
  757. aone = np.array(1)
  758. ainf = np.array(np.inf)
  759. self._assert_func(anan, anan)
  760. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  761. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  762. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  763. class TestArrayAssertLess:
  764. def _assert_func(self, *args, **kwargs):
  765. assert_array_less(*args, **kwargs)
  766. def test_simple_arrays(self):
  767. x = np.array([1.1, 2.2])
  768. y = np.array([1.2, 2.3])
  769. self._assert_func(x, y)
  770. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  771. y = np.array([1.0, 2.3])
  772. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  773. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  774. a = np.array([1, 3, 6, 20])
  775. b = np.array([2, 4, 6, 8])
  776. expected_msg = ('Mismatched elements: 2 / 4 (50%)\n'
  777. 'Mismatch at indices:\n'
  778. ' [2]: 6 (x), 6 (y)\n'
  779. ' [3]: 20 (x), 8 (y)\n'
  780. 'Max absolute difference among violations: 12\n'
  781. 'Max relative difference among violations: 1.5')
  782. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  783. self._assert_func(a, b)
  784. def test_rank2(self):
  785. x = np.array([[1.1, 2.2], [3.3, 4.4]])
  786. y = np.array([[1.2, 2.3], [3.4, 4.5]])
  787. self._assert_func(x, y)
  788. expected_msg = ('Mismatched elements: 4 / 4 (100%)\n'
  789. 'Mismatch at indices:\n'
  790. ' [0, 0]: 1.2 (x), 1.1 (y)\n'
  791. ' [0, 1]: 2.3 (x), 2.2 (y)\n'
  792. ' [1, 0]: 3.4 (x), 3.3 (y)\n'
  793. ' [1, 1]: 4.5 (x), 4.4 (y)\n'
  794. 'Max absolute difference among violations: 0.1\n'
  795. 'Max relative difference among violations: 0.09090909')
  796. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  797. self._assert_func(y, x)
  798. y = np.array([[1.0, 2.3], [3.4, 4.5]])
  799. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  800. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  801. def test_rank3(self):
  802. x = np.ones(shape=(2, 2, 2))
  803. y = np.ones(shape=(2, 2, 2)) + 1
  804. self._assert_func(x, y)
  805. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  806. y[0, 0, 0] = 0
  807. expected_msg = ('Mismatched elements: 1 / 8 (12.5%)\n'
  808. 'Mismatch at index:\n'
  809. ' [0, 0, 0]: 1.0 (x), 0.0 (y)\n'
  810. 'Max absolute difference among violations: 1.\n'
  811. 'Max relative difference among violations: inf')
  812. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  813. self._assert_func(x, y)
  814. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  815. def test_simple_items(self):
  816. x = 1.1
  817. y = 2.2
  818. self._assert_func(x, y)
  819. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  820. 'Max absolute difference among violations: 1.1\n'
  821. 'Max relative difference among violations: 1.')
  822. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  823. self._assert_func(y, x)
  824. y = np.array([2.2, 3.3])
  825. self._assert_func(x, y)
  826. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  827. y = np.array([1.0, 3.3])
  828. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  829. def test_simple_items_and_array(self):
  830. x = np.array([[621.345454, 390.5436, 43.54657, 626.4535],
  831. [54.54, 627.3399, 13., 405.5435],
  832. [543.545, 8.34, 91.543, 333.3]])
  833. y = 627.34
  834. self._assert_func(x, y)
  835. y = 8.339999
  836. self._assert_func(y, x)
  837. x = np.array([[3.4536, 2390.5436, 435.54657, 324525.4535],
  838. [5449.54, 999090.54, 130303.54, 405.5435],
  839. [543.545, 8.34, 91.543, 999090.53999]])
  840. y = 999090.54
  841. expected_msg = ('Mismatched elements: 1 / 12 (8.33%)\n'
  842. 'Mismatch at index:\n'
  843. ' [1, 1]: 999090.54 (x), 999090.54 (y)\n'
  844. 'Max absolute difference among violations: 0.\n'
  845. 'Max relative difference among violations: 0.')
  846. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  847. self._assert_func(x, y)
  848. expected_msg = ('Mismatched elements: 12 / 12 (100%)\n'
  849. 'First 5 mismatches are at indices:\n'
  850. ' [0, 0]: 999090.54 (x), 3.4536 (y)\n'
  851. ' [0, 1]: 999090.54 (x), 2390.5436 (y)\n'
  852. ' [0, 2]: 999090.54 (x), 435.54657 (y)\n'
  853. ' [0, 3]: 999090.54 (x), 324525.4535 (y)\n'
  854. ' [1, 0]: 999090.54 (x), 5449.54 (y)\n'
  855. 'Max absolute difference among violations: '
  856. '999087.0864\n'
  857. 'Max relative difference among violations: '
  858. '289288.5934676')
  859. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  860. self._assert_func(y, x)
  861. def test_zeroes(self):
  862. x = np.array([546456., 0, 15.455])
  863. y = np.array(87654.)
  864. expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
  865. 'Mismatch at index:\n'
  866. ' [0]: 546456.0 (x), 87654.0 (y)\n'
  867. 'Max absolute difference among violations: 458802.\n'
  868. 'Max relative difference among violations: 5.23423917')
  869. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  870. self._assert_func(x, y)
  871. expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
  872. 'Mismatch at indices:\n'
  873. ' [1]: 87654.0 (x), 0.0 (y)\n'
  874. ' [2]: 87654.0 (x), 15.455 (y)\n'
  875. 'Max absolute difference among violations: 87654.\n'
  876. 'Max relative difference among violations: '
  877. '5670.5626011')
  878. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  879. self._assert_func(y, x)
  880. y = 0
  881. expected_msg = ('Mismatched elements: 3 / 3 (100%)\n'
  882. 'Mismatch at indices:\n'
  883. ' [0]: 546456.0 (x), 0 (y)\n'
  884. ' [1]: 0.0 (x), 0 (y)\n'
  885. ' [2]: 15.455 (x), 0 (y)\n'
  886. 'Max absolute difference among violations: 546456.\n'
  887. 'Max relative difference among violations: inf')
  888. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  889. self._assert_func(x, y)
  890. expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
  891. 'Mismatch at index:\n'
  892. ' [1]: 0 (x), 0.0 (y)\n'
  893. 'Max absolute difference among violations: 0.\n'
  894. 'Max relative difference among violations: inf')
  895. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  896. self._assert_func(y, x)
  897. def test_nan_noncompare(self):
  898. anan = np.array(np.nan)
  899. aone = np.array(1)
  900. ainf = np.array(np.inf)
  901. self._assert_func(anan, anan)
  902. assert_raises(AssertionError, lambda: self._assert_func(aone, anan))
  903. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  904. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  905. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  906. def test_nan_noncompare_array(self):
  907. x = np.array([1.1, 2.2, 3.3])
  908. anan = np.array(np.nan)
  909. assert_raises(AssertionError, lambda: self._assert_func(x, anan))
  910. assert_raises(AssertionError, lambda: self._assert_func(anan, x))
  911. x = np.array([1.1, 2.2, np.nan])
  912. assert_raises(AssertionError, lambda: self._assert_func(x, anan))
  913. assert_raises(AssertionError, lambda: self._assert_func(anan, x))
  914. y = np.array([1.0, 2.0, np.nan])
  915. self._assert_func(y, x)
  916. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  917. def test_inf_compare(self):
  918. aone = np.array(1)
  919. ainf = np.array(np.inf)
  920. self._assert_func(aone, ainf)
  921. self._assert_func(-ainf, aone)
  922. self._assert_func(-ainf, ainf)
  923. assert_raises(AssertionError, lambda: self._assert_func(ainf, aone))
  924. assert_raises(AssertionError, lambda: self._assert_func(aone, -ainf))
  925. assert_raises(AssertionError, lambda: self._assert_func(ainf, ainf))
  926. assert_raises(AssertionError, lambda: self._assert_func(ainf, -ainf))
  927. assert_raises(AssertionError, lambda: self._assert_func(-ainf, -ainf))
  928. def test_inf_compare_array(self):
  929. x = np.array([1.1, 2.2, np.inf])
  930. ainf = np.array(np.inf)
  931. assert_raises(AssertionError, lambda: self._assert_func(x, ainf))
  932. assert_raises(AssertionError, lambda: self._assert_func(ainf, x))
  933. assert_raises(AssertionError, lambda: self._assert_func(x, -ainf))
  934. assert_raises(AssertionError, lambda: self._assert_func(-x, -ainf))
  935. assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x))
  936. self._assert_func(-ainf, x)
  937. def test_strict(self):
  938. """Test the behavior of the `strict` option."""
  939. x = np.zeros(3)
  940. y = np.ones(())
  941. self._assert_func(x, y)
  942. with pytest.raises(AssertionError):
  943. self._assert_func(x, y, strict=True)
  944. y = np.broadcast_to(y, x.shape)
  945. self._assert_func(x, y)
  946. with pytest.raises(AssertionError):
  947. self._assert_func(x, y.astype(np.float32), strict=True)
  948. @pytest.mark.filterwarnings(
  949. "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
  950. ".*:DeprecationWarning")
  951. @pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
  952. class TestWarns:
  953. def test_warn(self):
  954. def f():
  955. warnings.warn("yo")
  956. return 3
  957. before_filters = sys.modules['warnings'].filters[:]
  958. assert_equal(assert_warns(UserWarning, f), 3)
  959. after_filters = sys.modules['warnings'].filters
  960. assert_raises(AssertionError, assert_no_warnings, f)
  961. assert_equal(assert_no_warnings(lambda x: x, 1), 1)
  962. # Check that the warnings state is unchanged
  963. assert_equal(before_filters, after_filters,
  964. "assert_warns does not preserver warnings state")
  965. def test_context_manager(self):
  966. before_filters = sys.modules['warnings'].filters[:]
  967. with assert_warns(UserWarning):
  968. warnings.warn("yo")
  969. after_filters = sys.modules['warnings'].filters
  970. def no_warnings():
  971. with assert_no_warnings():
  972. warnings.warn("yo")
  973. assert_raises(AssertionError, no_warnings)
  974. assert_equal(before_filters, after_filters,
  975. "assert_warns does not preserver warnings state")
  976. def test_args(self):
  977. def f(a=0, b=1):
  978. warnings.warn("yo")
  979. return a + b
  980. assert assert_warns(UserWarning, f, b=20) == 20
  981. with pytest.raises(RuntimeError) as exc:
  982. # assert_warns cannot do regexp matching, use pytest.warns
  983. with assert_warns(UserWarning, match="A"):
  984. warnings.warn("B", UserWarning)
  985. assert "assert_warns" in str(exc)
  986. assert "pytest.warns" in str(exc)
  987. with pytest.raises(RuntimeError) as exc:
  988. # assert_warns cannot do regexp matching, use pytest.warns
  989. with assert_warns(UserWarning, wrong="A"):
  990. warnings.warn("B", UserWarning)
  991. assert "assert_warns" in str(exc)
  992. assert "pytest.warns" not in str(exc)
  993. def test_warn_wrong_warning(self):
  994. def f():
  995. warnings.warn("yo", DeprecationWarning)
  996. failed = False
  997. with warnings.catch_warnings():
  998. warnings.simplefilter("error", DeprecationWarning)
  999. try:
  1000. # Should raise a DeprecationWarning
  1001. assert_warns(UserWarning, f)
  1002. failed = True
  1003. except DeprecationWarning:
  1004. pass
  1005. if failed:
  1006. raise AssertionError("wrong warning caught by assert_warn")
  1007. class TestAssertAllclose:
  1008. def test_simple(self):
  1009. x = 1e-3
  1010. y = 1e-9
  1011. assert_allclose(x, y, atol=1)
  1012. assert_raises(AssertionError, assert_allclose, x, y)
  1013. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  1014. 'Max absolute difference among violations: 0.001\n'
  1015. 'Max relative difference among violations: 999999.')
  1016. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1017. assert_allclose(x, y)
  1018. z = 0
  1019. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  1020. 'Max absolute difference among violations: 1.e-09\n'
  1021. 'Max relative difference among violations: inf')
  1022. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1023. assert_allclose(y, z)
  1024. expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
  1025. 'Max absolute difference among violations: 1.e-09\n'
  1026. 'Max relative difference among violations: 1.')
  1027. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1028. assert_allclose(z, y)
  1029. a = np.array([x, y, x, y])
  1030. b = np.array([x, y, x, x])
  1031. assert_allclose(a, b, atol=1)
  1032. assert_raises(AssertionError, assert_allclose, a, b)
  1033. b[-1] = y * (1 + 1e-8)
  1034. assert_allclose(a, b)
  1035. assert_raises(AssertionError, assert_allclose, a, b, rtol=1e-9)
  1036. assert_allclose(6, 10, rtol=0.5)
  1037. assert_raises(AssertionError, assert_allclose, 10, 6, rtol=0.5)
  1038. b = np.array([x, y, x, x])
  1039. c = np.array([x, y, x, z])
  1040. expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
  1041. 'Mismatch at index:\n'
  1042. ' [3]: 0.001 (ACTUAL), 0.0 (DESIRED)\n'
  1043. 'Max absolute difference among violations: 0.001\n'
  1044. 'Max relative difference among violations: inf')
  1045. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1046. assert_allclose(b, c)
  1047. expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
  1048. 'Mismatch at index:\n'
  1049. ' [3]: 0.0 (ACTUAL), 0.001 (DESIRED)\n'
  1050. 'Max absolute difference among violations: 0.001\n'
  1051. 'Max relative difference among violations: 1.')
  1052. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1053. assert_allclose(c, b)
  1054. def test_min_int(self):
  1055. a = np.array([np.iinfo(np.int_).min], dtype=np.int_)
  1056. # Should not raise:
  1057. assert_allclose(a, a)
  1058. def test_report_fail_percentage(self):
  1059. a = np.array([1, 1, 1, 1])
  1060. b = np.array([1, 1, 1, 2])
  1061. expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
  1062. 'Mismatch at index:\n'
  1063. ' [3]: 1 (ACTUAL), 2 (DESIRED)\n'
  1064. 'Max absolute difference among violations: 1\n'
  1065. 'Max relative difference among violations: 0.5')
  1066. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1067. assert_allclose(a, b)
  1068. def test_equal_nan(self):
  1069. a = np.array([np.nan])
  1070. b = np.array([np.nan])
  1071. # Should not raise:
  1072. assert_allclose(a, b, equal_nan=True)
  1073. a = np.array([complex(np.nan, np.inf)])
  1074. b = np.array([complex(np.nan, np.inf)])
  1075. assert_allclose(a, b, equal_nan=True)
  1076. b = np.array([complex(np.nan, -np.inf)])
  1077. assert_allclose(a, b, equal_nan=True)
  1078. def test_not_equal_nan(self):
  1079. a = np.array([np.nan])
  1080. b = np.array([np.nan])
  1081. assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
  1082. a = np.array([complex(np.nan, np.inf)])
  1083. b = np.array([complex(np.nan, np.inf)])
  1084. assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
  1085. def test_equal_nan_default(self):
  1086. # Make sure equal_nan default behavior remains unchanged. (All
  1087. # of these functions use assert_array_compare under the hood.)
  1088. # None of these should raise.
  1089. a = np.array([np.nan])
  1090. b = np.array([np.nan])
  1091. assert_array_equal(a, b)
  1092. assert_array_almost_equal(a, b)
  1093. assert_array_less(a, b)
  1094. assert_allclose(a, b)
  1095. def test_report_max_relative_error(self):
  1096. a = np.array([0, 1])
  1097. b = np.array([0, 2])
  1098. expected_msg = 'Max relative difference among violations: 0.5'
  1099. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1100. assert_allclose(a, b)
  1101. def test_timedelta(self):
  1102. # see gh-18286
  1103. a = np.array([[1, 2, 3, "NaT"]], dtype="m8[ns]")
  1104. assert_allclose(a, a)
  1105. def test_error_message_unsigned(self):
  1106. """Check the message is formatted correctly when overflow can occur
  1107. (gh21768)"""
  1108. # Ensure to test for potential overflow in the case of:
  1109. # x - y
  1110. # and
  1111. # y - x
  1112. x = np.asarray([0, 1, 8], dtype='uint8')
  1113. y = np.asarray([4, 4, 4], dtype='uint8')
  1114. expected_msg = 'Max absolute difference among violations: 4'
  1115. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1116. assert_allclose(x, y, atol=3)
  1117. def test_strict(self):
  1118. """Test the behavior of the `strict` option."""
  1119. x = np.ones(3)
  1120. y = np.ones(())
  1121. assert_allclose(x, y)
  1122. with pytest.raises(AssertionError):
  1123. assert_allclose(x, y, strict=True)
  1124. assert_allclose(x, x)
  1125. with pytest.raises(AssertionError):
  1126. assert_allclose(x, x.astype(np.float32), strict=True)
  1127. def test_infs(self):
  1128. a = np.array([np.inf])
  1129. b = np.array([np.inf])
  1130. assert_allclose(a, b)
  1131. b = np.array([3.])
  1132. expected_msg = 'inf location mismatch:'
  1133. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1134. assert_allclose(a, b)
  1135. b = np.array([-np.inf])
  1136. expected_msg = 'inf values mismatch:'
  1137. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1138. assert_allclose(a, b)
  1139. b = np.array([complex(np.inf, 1.)])
  1140. expected_msg = 'inf values mismatch:'
  1141. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1142. assert_allclose(a, b)
  1143. a = np.array([complex(np.inf, 1.)])
  1144. b = np.array([complex(np.inf, 1.)])
  1145. assert_allclose(a, b)
  1146. b = np.array([complex(np.inf, 2.)])
  1147. expected_msg = 'inf values mismatch:'
  1148. with pytest.raises(AssertionError, match=re.escape(expected_msg)):
  1149. assert_allclose(a, b)
  1150. class TestArrayAlmostEqualNulp:
  1151. def test_float64_pass(self):
  1152. # The number of units of least precision
  1153. # In this case, use a few places above the lowest level (ie nulp=1)
  1154. nulp = 5
  1155. x = np.linspace(-20, 20, 50, dtype=np.float64)
  1156. x = 10**x
  1157. x = np.r_[-x, x]
  1158. # Addition
  1159. eps = np.finfo(x.dtype).eps
  1160. y = x + x * eps * nulp / 2.
  1161. assert_array_almost_equal_nulp(x, y, nulp)
  1162. # Subtraction
  1163. epsneg = np.finfo(x.dtype).epsneg
  1164. y = x - x * epsneg * nulp / 2.
  1165. assert_array_almost_equal_nulp(x, y, nulp)
  1166. def test_float64_fail(self):
  1167. nulp = 5
  1168. x = np.linspace(-20, 20, 50, dtype=np.float64)
  1169. x = 10**x
  1170. x = np.r_[-x, x]
  1171. eps = np.finfo(x.dtype).eps
  1172. y = x + x * eps * nulp * 2.
  1173. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1174. x, y, nulp)
  1175. epsneg = np.finfo(x.dtype).epsneg
  1176. y = x - x * epsneg * nulp * 2.
  1177. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1178. x, y, nulp)
  1179. def test_float64_ignore_nan(self):
  1180. # Ignore ULP differences between various NAN's
  1181. # Note that MIPS may reverse quiet and signaling nans
  1182. # so we use the builtin version as a base.
  1183. offset = np.uint64(0xffffffff)
  1184. nan1_i64 = np.array(np.nan, dtype=np.float64).view(np.uint64)
  1185. nan2_i64 = nan1_i64 ^ offset # nan payload on MIPS is all ones.
  1186. nan1_f64 = nan1_i64.view(np.float64)
  1187. nan2_f64 = nan2_i64.view(np.float64)
  1188. assert_array_max_ulp(nan1_f64, nan2_f64, 0)
  1189. def test_float32_pass(self):
  1190. nulp = 5
  1191. x = np.linspace(-20, 20, 50, dtype=np.float32)
  1192. x = 10**x
  1193. x = np.r_[-x, x]
  1194. eps = np.finfo(x.dtype).eps
  1195. y = x + x * eps * nulp / 2.
  1196. assert_array_almost_equal_nulp(x, y, nulp)
  1197. epsneg = np.finfo(x.dtype).epsneg
  1198. y = x - x * epsneg * nulp / 2.
  1199. assert_array_almost_equal_nulp(x, y, nulp)
  1200. def test_float32_fail(self):
  1201. nulp = 5
  1202. x = np.linspace(-20, 20, 50, dtype=np.float32)
  1203. x = 10**x
  1204. x = np.r_[-x, x]
  1205. eps = np.finfo(x.dtype).eps
  1206. y = x + x * eps * nulp * 2.
  1207. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1208. x, y, nulp)
  1209. epsneg = np.finfo(x.dtype).epsneg
  1210. y = x - x * epsneg * nulp * 2.
  1211. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1212. x, y, nulp)
  1213. def test_float32_ignore_nan(self):
  1214. # Ignore ULP differences between various NAN's
  1215. # Note that MIPS may reverse quiet and signaling nans
  1216. # so we use the builtin version as a base.
  1217. offset = np.uint32(0xffff)
  1218. nan1_i32 = np.array(np.nan, dtype=np.float32).view(np.uint32)
  1219. nan2_i32 = nan1_i32 ^ offset # nan payload on MIPS is all ones.
  1220. nan1_f32 = nan1_i32.view(np.float32)
  1221. nan2_f32 = nan2_i32.view(np.float32)
  1222. assert_array_max_ulp(nan1_f32, nan2_f32, 0)
  1223. def test_float16_pass(self):
  1224. nulp = 5
  1225. x = np.linspace(-4, 4, 10, dtype=np.float16)
  1226. x = 10**x
  1227. x = np.r_[-x, x]
  1228. eps = np.finfo(x.dtype).eps
  1229. y = x + x * eps * nulp / 2.
  1230. assert_array_almost_equal_nulp(x, y, nulp)
  1231. epsneg = np.finfo(x.dtype).epsneg
  1232. y = x - x * epsneg * nulp / 2.
  1233. assert_array_almost_equal_nulp(x, y, nulp)
  1234. def test_float16_fail(self):
  1235. nulp = 5
  1236. x = np.linspace(-4, 4, 10, dtype=np.float16)
  1237. x = 10**x
  1238. x = np.r_[-x, x]
  1239. eps = np.finfo(x.dtype).eps
  1240. y = x + x * eps * nulp * 2.
  1241. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1242. x, y, nulp)
  1243. epsneg = np.finfo(x.dtype).epsneg
  1244. y = x - x * epsneg * nulp * 2.
  1245. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1246. x, y, nulp)
  1247. def test_float16_ignore_nan(self):
  1248. # Ignore ULP differences between various NAN's
  1249. # Note that MIPS may reverse quiet and signaling nans
  1250. # so we use the builtin version as a base.
  1251. offset = np.uint16(0xff)
  1252. nan1_i16 = np.array(np.nan, dtype=np.float16).view(np.uint16)
  1253. nan2_i16 = nan1_i16 ^ offset # nan payload on MIPS is all ones.
  1254. nan1_f16 = nan1_i16.view(np.float16)
  1255. nan2_f16 = nan2_i16.view(np.float16)
  1256. assert_array_max_ulp(nan1_f16, nan2_f16, 0)
  1257. def test_complex128_pass(self):
  1258. nulp = 5
  1259. x = np.linspace(-20, 20, 50, dtype=np.float64)
  1260. x = 10**x
  1261. x = np.r_[-x, x]
  1262. xi = x + x * 1j
  1263. eps = np.finfo(x.dtype).eps
  1264. y = x + x * eps * nulp / 2.
  1265. assert_array_almost_equal_nulp(xi, x + y * 1j, nulp)
  1266. assert_array_almost_equal_nulp(xi, y + x * 1j, nulp)
  1267. # The test condition needs to be at least a factor of sqrt(2) smaller
  1268. # because the real and imaginary parts both change
  1269. y = x + x * eps * nulp / 4.
  1270. assert_array_almost_equal_nulp(xi, y + y * 1j, nulp)
  1271. epsneg = np.finfo(x.dtype).epsneg
  1272. y = x - x * epsneg * nulp / 2.
  1273. assert_array_almost_equal_nulp(xi, x + y * 1j, nulp)
  1274. assert_array_almost_equal_nulp(xi, y + x * 1j, nulp)
  1275. y = x - x * epsneg * nulp / 4.
  1276. assert_array_almost_equal_nulp(xi, y + y * 1j, nulp)
  1277. def test_complex128_fail(self):
  1278. nulp = 5
  1279. x = np.linspace(-20, 20, 50, dtype=np.float64)
  1280. x = 10**x
  1281. x = np.r_[-x, x]
  1282. xi = x + x * 1j
  1283. eps = np.finfo(x.dtype).eps
  1284. y = x + x * eps * nulp * 2.
  1285. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1286. xi, x + y * 1j, nulp)
  1287. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1288. xi, y + x * 1j, nulp)
  1289. # The test condition needs to be at least a factor of sqrt(2) smaller
  1290. # because the real and imaginary parts both change
  1291. y = x + x * eps * nulp
  1292. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1293. xi, y + y * 1j, nulp)
  1294. epsneg = np.finfo(x.dtype).epsneg
  1295. y = x - x * epsneg * nulp * 2.
  1296. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1297. xi, x + y * 1j, nulp)
  1298. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1299. xi, y + x * 1j, nulp)
  1300. y = x - x * epsneg * nulp
  1301. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1302. xi, y + y * 1j, nulp)
  1303. def test_complex64_pass(self):
  1304. nulp = 5
  1305. x = np.linspace(-20, 20, 50, dtype=np.float32)
  1306. x = 10**x
  1307. x = np.r_[-x, x]
  1308. xi = x + x * 1j
  1309. eps = np.finfo(x.dtype).eps
  1310. y = x + x * eps * nulp / 2.
  1311. assert_array_almost_equal_nulp(xi, x + y * 1j, nulp)
  1312. assert_array_almost_equal_nulp(xi, y + x * 1j, nulp)
  1313. y = x + x * eps * nulp / 4.
  1314. assert_array_almost_equal_nulp(xi, y + y * 1j, nulp)
  1315. epsneg = np.finfo(x.dtype).epsneg
  1316. y = x - x * epsneg * nulp / 2.
  1317. assert_array_almost_equal_nulp(xi, x + y * 1j, nulp)
  1318. assert_array_almost_equal_nulp(xi, y + x * 1j, nulp)
  1319. y = x - x * epsneg * nulp / 4.
  1320. assert_array_almost_equal_nulp(xi, y + y * 1j, nulp)
  1321. def test_complex64_fail(self):
  1322. nulp = 5
  1323. x = np.linspace(-20, 20, 50, dtype=np.float32)
  1324. x = 10**x
  1325. x = np.r_[-x, x]
  1326. xi = x + x * 1j
  1327. eps = np.finfo(x.dtype).eps
  1328. y = x + x * eps * nulp * 2.
  1329. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1330. xi, x + y * 1j, nulp)
  1331. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1332. xi, y + x * 1j, nulp)
  1333. y = x + x * eps * nulp
  1334. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1335. xi, y + y * 1j, nulp)
  1336. epsneg = np.finfo(x.dtype).epsneg
  1337. y = x - x * epsneg * nulp * 2.
  1338. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1339. xi, x + y * 1j, nulp)
  1340. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1341. xi, y + x * 1j, nulp)
  1342. y = x - x * epsneg * nulp
  1343. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  1344. xi, y + y * 1j, nulp)
  1345. class TestULP:
  1346. def test_equal(self):
  1347. x = np.random.randn(10)
  1348. assert_array_max_ulp(x, x, maxulp=0)
  1349. def test_single(self):
  1350. # Generate 1 + small deviation, check that adding eps gives a few UNL
  1351. x = np.ones(10).astype(np.float32)
  1352. x += 0.01 * np.random.randn(10).astype(np.float32)
  1353. eps = np.finfo(np.float32).eps
  1354. assert_array_max_ulp(x, x + eps, maxulp=20)
  1355. def test_double(self):
  1356. # Generate 1 + small deviation, check that adding eps gives a few UNL
  1357. x = np.ones(10).astype(np.float64)
  1358. x += 0.01 * np.random.randn(10).astype(np.float64)
  1359. eps = np.finfo(np.float64).eps
  1360. assert_array_max_ulp(x, x + eps, maxulp=200)
  1361. def test_inf(self):
  1362. for dt in [np.float32, np.float64]:
  1363. inf = np.array([np.inf]).astype(dt)
  1364. big = np.array([np.finfo(dt).max])
  1365. assert_array_max_ulp(inf, big, maxulp=200)
  1366. def test_nan(self):
  1367. # Test that nan is 'far' from small, tiny, inf, max and min
  1368. for dt in [np.float32, np.float64]:
  1369. if dt == np.float32:
  1370. maxulp = 1e6
  1371. else:
  1372. maxulp = 1e12
  1373. inf = np.array([np.inf]).astype(dt)
  1374. nan = np.array([np.nan]).astype(dt)
  1375. big = np.array([np.finfo(dt).max])
  1376. tiny = np.array([np.finfo(dt).tiny])
  1377. zero = np.array([0.0]).astype(dt)
  1378. nzero = np.array([-0.0]).astype(dt)
  1379. assert_raises(AssertionError,
  1380. lambda: assert_array_max_ulp(nan, inf,
  1381. maxulp=maxulp))
  1382. assert_raises(AssertionError,
  1383. lambda: assert_array_max_ulp(nan, big,
  1384. maxulp=maxulp))
  1385. assert_raises(AssertionError,
  1386. lambda: assert_array_max_ulp(nan, tiny,
  1387. maxulp=maxulp))
  1388. assert_raises(AssertionError,
  1389. lambda: assert_array_max_ulp(nan, zero,
  1390. maxulp=maxulp))
  1391. assert_raises(AssertionError,
  1392. lambda: assert_array_max_ulp(nan, nzero,
  1393. maxulp=maxulp))
  1394. class TestStringEqual:
  1395. def test_simple(self):
  1396. assert_string_equal("hello", "hello")
  1397. assert_string_equal("hello\nmultiline", "hello\nmultiline")
  1398. with pytest.raises(AssertionError) as exc_info:
  1399. assert_string_equal("foo\nbar", "hello\nbar")
  1400. msg = str(exc_info.value)
  1401. assert_equal(msg, "Differences in strings:\n- foo\n+ hello")
  1402. assert_raises(AssertionError,
  1403. lambda: assert_string_equal("foo", "hello"))
  1404. def test_regex(self):
  1405. assert_string_equal("a+*b", "a+*b")
  1406. assert_raises(AssertionError,
  1407. lambda: assert_string_equal("aaa", "a+b"))
  1408. def assert_warn_len_equal(mod, n_in_context):
  1409. try:
  1410. mod_warns = mod.__warningregistry__
  1411. except AttributeError:
  1412. # the lack of a __warningregistry__
  1413. # attribute means that no warning has
  1414. # occurred; this can be triggered in
  1415. # a parallel test scenario, while in
  1416. # a serial test scenario an initial
  1417. # warning (and therefore the attribute)
  1418. # are always created first
  1419. mod_warns = {}
  1420. num_warns = len(mod_warns)
  1421. if 'version' in mod_warns:
  1422. # Python adds a 'version' entry to the registry,
  1423. # do not count it.
  1424. num_warns -= 1
  1425. assert_equal(num_warns, n_in_context)
  1426. def test_warn_len_equal_call_scenarios():
  1427. # assert_warn_len_equal is called under
  1428. # varying circumstances depending on serial
  1429. # vs. parallel test scenarios; this test
  1430. # simply aims to probe both code paths and
  1431. # check that no assertion is uncaught
  1432. # parallel scenario -- no warning issued yet
  1433. class mod:
  1434. pass
  1435. mod_inst = mod()
  1436. assert_warn_len_equal(mod=mod_inst,
  1437. n_in_context=0)
  1438. # serial test scenario -- the __warningregistry__
  1439. # attribute should be present
  1440. class mod:
  1441. def __init__(self):
  1442. self.__warningregistry__ = {'warning1': 1,
  1443. 'warning2': 2}
  1444. mod_inst = mod()
  1445. assert_warn_len_equal(mod=mod_inst,
  1446. n_in_context=2)
  1447. def _get_fresh_mod():
  1448. # Get this module, with warning registry empty
  1449. my_mod = sys.modules[__name__]
  1450. try:
  1451. my_mod.__warningregistry__.clear()
  1452. except AttributeError:
  1453. # will not have a __warningregistry__ unless warning has been
  1454. # raised in the module at some point
  1455. pass
  1456. return my_mod
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_clear_and_catch_warnings():
    """Check how ``clear_and_catch_warnings`` affects this module's registry."""
    # Initial state of module, no warnings
    my_mod = _get_fresh_mod()
    assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
    # With this module listed, the registry is empty after the context.
    with clear_and_catch_warnings(modules=[my_mod]):
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_equal(my_mod.__warningregistry__, {})
    # Without specified modules, don't clear warnings during context.
    # catch_warnings doesn't make an entry for 'ignore'.
    with clear_and_catch_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # Manually adding two warnings to the registry:
    my_mod.__warningregistry__ = {'warning1': 1,
                                  'warning2': 2}
    # Confirm that specifying the module keeps the old warnings and does not
    # add the new one.
    with clear_and_catch_warnings(modules=[my_mod]):
        warnings.simplefilter('ignore')
        warnings.warn('Another warning')
    assert_warn_len_equal(my_mod, 2)
    # Another warning without a module spec: afterwards the registry holds
    # no counted entries any more.
    with clear_and_catch_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Another warning')
    assert_warn_len_equal(my_mod, 0)
@pytest.mark.filterwarnings(
    "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
    ".*:DeprecationWarning")
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_suppress_warnings_module():
    """Check module-based filtering of ``suppress_warnings``."""
    # Initial state of module, no warnings
    my_mod = _get_fresh_mod()
    assert_equal(getattr(my_mod, '__warningregistry__', {}), {})

    def warn_other_module():
        # Apply along axis is implemented in python; stacklevel=2 means
        # we end up inside its module, not ours.
        def warn(arr):
            warnings.warn("Some warning 2", stacklevel=2)
            return arr
        np.apply_along_axis(warn, 0, [0])

    # Test module based warning suppression:
    assert_warn_len_equal(my_mod, 0)
    with suppress_warnings() as sup:
        sup.record(UserWarning)
        # Suppress the warning attributed to the other module (its file may
        # have a .pyc ending).  This has to be changed if apply_along_axis
        # is moved to a different module.
        sup.filter(module=np.lib._shape_base_impl)
        warnings.warn("Some warning")
        warn_other_module()
    # Check that the suppression matched the other module's file correctly:
    # only this module's "Some warning" was recorded, while the warning
    # attributed to np.lib._shape_base_impl was filtered out.
    assert_equal(len(sup.log), 1)
    assert_equal(sup.log[0].message.args[0], "Some warning")
    assert_warn_len_equal(my_mod, 0)
    sup = suppress_warnings()
    # Filter every warning raised from this module:
    sup.filter(module=my_mod)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # And test that re-entering the same instance works:
    sup.filter(module=my_mod)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # Without specified modules
    with suppress_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
@pytest.mark.filterwarnings(
    "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
    ".*:DeprecationWarning")
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_suppress_warnings_type():
    """Check category-based filtering of ``suppress_warnings``."""
    # Initial state of module, no warnings
    my_mod = _get_fresh_mod()
    assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
    # Category-based suppression: filtering UserWarning must leave no
    # counted entry in this module's registry.
    with suppress_warnings() as sup:
        sup.filter(UserWarning)
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # Same check with the instance created before entering the context.
    sup = suppress_warnings()
    sup.filter(UserWarning)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # And test that re-entering the same instance (now with an extra module
    # filter) works:
    sup.filter(module=my_mod)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # Without specified modules
    with suppress_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
  1558. @pytest.mark.filterwarnings(
  1559. "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
  1560. ".*:DeprecationWarning")
  1561. @pytest.mark.thread_unsafe(
  1562. reason="uses deprecated thread-unsafe warnings control utilities"
  1563. )
  1564. def test_suppress_warnings_decorate_no_record():
  1565. sup = suppress_warnings()
  1566. sup.filter(UserWarning)
  1567. @sup
  1568. def warn(category):
  1569. warnings.warn('Some warning', category)
  1570. with warnings.catch_warnings(record=True) as w:
  1571. warnings.simplefilter("always")
  1572. warn(UserWarning) # should be suppressed
  1573. warn(RuntimeWarning)
  1574. assert_equal(len(w), 1)
@pytest.mark.filterwarnings(
    "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
    ".*:DeprecationWarning")
@pytest.mark.thread_unsafe(
    reason="uses deprecated thread-unsafe warnings control utilities"
)
def test_suppress_warnings_record():
    """Check the logs produced by ``suppress_warnings.record``."""
    sup = suppress_warnings()
    # Registered before entering: catches everything not otherwise handled.
    log1 = sup.record()
    with sup:
        # Registered inside the context: only this specific message.
        log2 = sup.record(message='Some other warning 2')
        sup.filter(message='Some warning')
        warnings.warn('Some warning')
        warnings.warn('Some other warning')
        warnings.warn('Some other warning 2')
        # 'Some warning' is filtered; each of the other two lands in exactly
        # one of the logs, and both appear in sup.log.
        assert_equal(len(sup.log), 2)
        assert_equal(len(log1), 1)
        assert_equal(len(log2), 1)
        assert_equal(log2[0].message.args[0], 'Some other warning 2')
    # Do it again, with the same context to see if some warnings survived:
    # the counts must be the same as in the first pass, not accumulated.
    with sup:
        log2 = sup.record(message='Some other warning 2')
        sup.filter(message='Some warning')
        warnings.warn('Some warning')
        warnings.warn('Some other warning')
        warnings.warn('Some other warning 2')
        assert_equal(len(sup.log), 2)
        assert_equal(len(log1), 1)
        assert_equal(len(log2), 1)
        assert_equal(log2[0].message.args[0], 'Some other warning 2')
    # Test nested: the inner context records only 'Some warning'; the
    # remaining warning reaches the outer recorder.
    with suppress_warnings() as sup:
        sup.record()
        with suppress_warnings() as sup2:
            sup2.record(message='Some warning')
            warnings.warn('Some warning')
            warnings.warn('Some other warning')
            assert_equal(len(sup2.log), 1)
        # includes a DeprecationWarning for suppress_warnings
        assert_equal(len(sup.log), 2)
@pytest.mark.filterwarnings(
    "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
    ".*:DeprecationWarning")
@pytest.mark.thread_unsafe(
    reason="uses deprecated thread-unsafe warnings control utilities"
)
def test_suppress_warnings_forwarding():
    """Check the forwarding rules ("always", "location", "module", "once").

    An outer ``suppress_warnings`` records whatever the inner one forwards.
    The expected counts depend on where each ``warnings.warn`` call sits, so
    the call locations in this function are significant.
    """
    def warn_other_module():
        # Apply along axis is implemented in python; stacklevel=2 means
        # we end up inside its module, not ours.
        def warn(arr):
            warnings.warn("Some warning", stacklevel=2)
            return arr
        np.apply_along_axis(warn, 0, [0])

    with suppress_warnings() as sup:
        sup.record()
        # "always": every call is forwarded.
        with suppress_warnings("always"):
            for i in range(2):
                warnings.warn("Some warning")
        # includes a DeprecationWarning for suppress_warnings
        assert_equal(len(sup.log), 3)

    with suppress_warnings() as sup:
        sup.record()
        # "location": repeated calls from the same line are forwarded once.
        with suppress_warnings("location"):
            for i in range(2):
                warnings.warn("Some warning")
                warnings.warn("Some warning")
        # includes a DeprecationWarning for suppress_warnings
        assert_equal(len(sup.log), 3)

    with suppress_warnings() as sup:
        sup.record()
        # "module": at most one forwarded warning per originating module.
        with suppress_warnings("module"):
            for i in range(2):
                warnings.warn("Some warning")
                warnings.warn("Some warning")
            warn_other_module()
        # includes a DeprecationWarning for suppress_warnings
        assert_equal(len(sup.log), 3)

    with suppress_warnings() as sup:
        sup.record()
        # "once": at most one forwarded warning per message text.
        with suppress_warnings("once"):
            for i in range(2):
                warnings.warn("Some warning")
                warnings.warn("Some other warning")
            warn_other_module()
        # includes a DeprecationWarning for suppress_warnings
        assert_equal(len(sup.log), 3)
  1662. def test_tempdir():
  1663. with tempdir() as tdir:
  1664. fpath = os.path.join(tdir, 'tmp')
  1665. with open(fpath, 'w'):
  1666. pass
  1667. assert_(not os.path.isdir(tdir))
  1668. raised = False
  1669. try:
  1670. with tempdir() as tdir:
  1671. raise ValueError
  1672. except ValueError:
  1673. raised = True
  1674. assert_(raised)
  1675. assert_(not os.path.isdir(tdir))
  1676. def test_temppath():
  1677. with temppath() as fpath:
  1678. with open(fpath, 'w'):
  1679. pass
  1680. assert_(not os.path.isfile(fpath))
  1681. raised = False
  1682. try:
  1683. with temppath() as fpath:
  1684. raise ValueError
  1685. except ValueError:
  1686. raised = True
  1687. assert_(raised)
  1688. assert_(not os.path.isfile(fpath))
  1689. class my_cacw(clear_and_catch_warnings):
  1690. class_modules = (sys.modules[__name__],)
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_clear_and_catch_warnings_inherit():
    """A subclass can supply default modules through ``class_modules``."""
    # Test can subclass and add default modules: my_cacw lists this module,
    # so no module argument is passed here.
    my_mod = _get_fresh_mod()
    with my_cacw():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_equal(my_mod.__warningregistry__, {})
  1699. @pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
  1700. @pytest.mark.thread_unsafe(reason="garbage collector is global state")
  1701. class TestAssertNoGcCycles:
  1702. """ Test assert_no_gc_cycles """
  1703. def test_passes(self):
  1704. def no_cycle():
  1705. b = []
  1706. b.append([])
  1707. return b
  1708. with assert_no_gc_cycles():
  1709. no_cycle()
  1710. assert_no_gc_cycles(no_cycle)
  1711. def test_asserts(self):
  1712. def make_cycle():
  1713. a = []
  1714. a.append(a)
  1715. a.append(a)
  1716. return a
  1717. with assert_raises(AssertionError):
  1718. with assert_no_gc_cycles():
  1719. make_cycle()
  1720. with assert_raises(AssertionError):
  1721. assert_no_gc_cycles(make_cycle)
  1722. @pytest.mark.slow
  1723. def test_fails(self):
  1724. """
  1725. Test that in cases where the garbage cannot be collected, we raise an
  1726. error, instead of hanging forever trying to clear it.
  1727. """
  1728. class ReferenceCycleInDel:
  1729. """
  1730. An object that not only contains a reference cycle, but creates new
  1731. cycles whenever it's garbage-collected and its __del__ runs
  1732. """
  1733. make_cycle = True
  1734. def __init__(self):
  1735. self.cycle = self
  1736. def __del__(self):
  1737. # break the current cycle so that `self` can be freed
  1738. self.cycle = None
  1739. if ReferenceCycleInDel.make_cycle:
  1740. # but create a new one so that the garbage collector (GC) has more
  1741. # work to do.
  1742. ReferenceCycleInDel()
  1743. try:
  1744. w = weakref.ref(ReferenceCycleInDel())
  1745. try:
  1746. with assert_raises(RuntimeError):
  1747. # this will be unable to get a baseline empty garbage
  1748. assert_no_gc_cycles(lambda: None)
  1749. except AssertionError:
  1750. # the above test is only necessary if the GC actually tried to free
  1751. # our object anyway.
  1752. if w() is not None:
  1753. pytest.skip("GC does not call __del__ on cyclic objects")
  1754. raise
  1755. finally:
  1756. # make sure that we stop creating reference cycles
  1757. ReferenceCycleInDel.make_cycle = False