| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123 |
- import itertools
- import os
- import re
- import sys
- import warnings
- import weakref
- import pytest
- import numpy as np
- import numpy._core._multiarray_umath as ncu
- from numpy.testing import (
- HAS_REFCOUNT,
- assert_,
- assert_allclose,
- assert_almost_equal,
- assert_approx_equal,
- assert_array_almost_equal,
- assert_array_almost_equal_nulp,
- assert_array_equal,
- assert_array_less,
- assert_array_max_ulp,
- assert_equal,
- assert_no_gc_cycles,
- assert_no_warnings,
- assert_raises,
- assert_string_equal,
- assert_warns,
- build_err_msg,
- clear_and_catch_warnings,
- suppress_warnings,
- tempdir,
- temppath,
- )
class _GenericTest:
    """Equality checks shared by every assertion-specific test class.

    Subclasses supply the assertion under test by overriding
    ``_assert_func``; each ``test_*`` method below then exercises it.
    """

    def _assert_func(self, *args, **kwargs):
        # Base placeholder: accepts any arguments and never raises.
        return None

    def _test_equal(self, a, b):
        # Equal inputs: the assertion must complete without raising.
        self._assert_func(a, b)

    def _test_not_equal(self, a, b):
        # Unequal inputs: the assertion must raise AssertionError.
        with assert_raises(AssertionError):
            self._assert_func(a, b)

    def test_array_rank1_eq(self):
        """Two identical 1-D arrays are reported as equal."""
        left, right = np.array([1, 2]), np.array([1, 2])
        self._test_equal(left, right)

    def test_array_rank1_noteq(self):
        """1-D arrays differing in one element are reported as unequal."""
        left, right = np.array([1, 2]), np.array([2, 2])
        self._test_not_equal(left, right)

    def test_array_rank2_eq(self):
        """Two identical 2-D arrays are reported as equal."""
        left, right = np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])
        self._test_equal(left, right)

    def test_array_diffshape(self):
        """Arrays whose shapes differ are reported as unequal."""
        left, right = np.array([1, 2]), np.array([[1, 2], [1, 2]])
        self._test_not_equal(left, right)

    def test_objarray(self):
        """An object array of ones compares equal to the scalar 1."""
        self._test_equal(np.array([1, 1], dtype=object), 1)

    def test_array_likes(self):
        """A list and a tuple holding the same values compare equal."""
        self._test_equal([1, 2, 3], (1, 2, 3))
class TestArrayEqual(_GenericTest):
    """Tests for ``assert_array_equal`` on top of the shared generic checks."""

    def _assert_func(self, *args, **kwargs):
        assert_array_equal(*args, **kwargs)

    def _check_filled_arrays(self, shape):
        """For each tested dtype, an array of ones with ``shape`` must equal
        its copy, and an array of zeros must not equal it."""
        def check(dtype):
            ones = np.empty(shape, dtype)
            ones.fill(1)
            same = ones.copy()
            zeros = ones.copy()
            zeros.fill(0)
            self._test_equal(ones, same)
            self._test_not_equal(zeros, same)

        # Numeric type codes: bool, signed and unsigned integers, and real
        # and complex floats.  Object dtype is not covered here.
        for dtype in '?bhilqpBHILQPfdgFDG':
            check(dtype)
        # Fixed-width byte and unicode strings.
        for dtype in ['S1', 'U1']:
            check(dtype)

    def test_generic_rank1(self):
        """Test rank 1 array for all dtypes."""
        self._check_filled_arrays(2)

    def test_0_ndim_array(self):
        # 0-d arrays built from integers wider than 64 bits.
        x = np.array(473963742225900817127911193656584771)
        y = np.array(18535119325151578301457182298393896)
        with pytest.raises(AssertionError) as exc_info:
            self._assert_func(x, y)
        msg = str(exc_info.value)
        assert_('Mismatched elements: 1 / 1 (100%)\n'
                in msg)

        y = x
        self._assert_func(x, y)

        x = np.array(4395065348745.5643764887869876)
        y = np.array(0)
        expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
                        'Max absolute difference among violations: '
                        '4.39506535e+12\n'
                        'Max relative difference among violations: inf\n')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(x, y)

        x = y
        self._assert_func(x, y)

    def test_generic_rank3(self):
        """Test rank 3 array for all dtypes."""
        self._check_filled_arrays((4, 2, 3))

    def test_nan_array(self):
        """Test arrays with nan values in them."""
        a = np.array([1, 2, np.nan])
        b = np.array([1, 2, np.nan])
        self._test_equal(a, b)

        c = np.array([1, 2, 3])
        self._test_not_equal(c, b)

    def test_string_arrays(self):
        """Test equal string arrays match and differing ones do not."""
        a = np.array(['floupi', 'floupa'])
        b = np.array(['floupi', 'floupa'])
        self._test_equal(a, b)

        c = np.array(['floupipi', 'floupa'])
        self._test_not_equal(c, b)

    def test_recarrays(self):
        """Test record arrays."""
        a = np.empty(2, [('floupi', float), ('floupa', float)])
        a['floupi'] = [1, 2]
        a['floupa'] = [1, 2]
        b = a.copy()
        self._test_equal(a, b)

        # ``c`` has an extra field; the test expects the comparison to
        # raise TypeError rather than report a plain mismatch.
        c = np.empty(2, [('floupipi', float),
                         ('floupi', float), ('floupa', float)])
        c['floupipi'] = a['floupi'].copy()
        c['floupa'] = a['floupa'].copy()

        with pytest.raises(TypeError):
            self._test_not_equal(c, b)

    def test_masked_nan_inf(self):
        # Regression test for gh-11121
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
        b = np.array([3., np.nan, 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
        b = np.array([np.inf, 4., 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)

    # Also provides test cases for gh-11121
    def test_masked_scalar(self):
        # Test masked scalar vs. plain/masked scalar
        for a_val, b_val, b_masked in itertools.product(
            [3., np.nan, np.inf],
            [3., 4., np.nan, np.inf, -np.inf],
            [False, True],
        ):
            a = np.ma.MaskedArray(a_val, mask=True)
            b = np.ma.MaskedArray(b_val, mask=True) if b_masked else np.array(b_val)
            self._test_equal(a, b)
            self._test_equal(b, a)

        # Test masked scalar vs. plain array
        for a_val, b_val in itertools.product(
            [3., np.nan, -np.inf],
            itertools.product([3., 4., np.nan, np.inf, -np.inf], repeat=2),
        ):
            a = np.ma.MaskedArray(a_val, mask=True)
            b = np.array(b_val)
            self._test_equal(a, b)
            self._test_equal(b, a)

        # Test masked scalar vs. masked array
        for a_val, b_val, b_mask in itertools.product(
            [3., np.nan, np.inf],
            itertools.product([3., 4., np.nan, np.inf, -np.inf], repeat=2),
            itertools.product([False, True], repeat=2),
        ):
            a = np.ma.MaskedArray(a_val, mask=True)
            b = np.ma.MaskedArray(b_val, mask=b_mask)
            self._test_equal(a, b)
            self._test_equal(b, a)

    def test_subclass_that_overrides_eq(self):
        # While we cannot guarantee testing functions will always work for
        # subclasses, the tests should ideally rely only on subclasses having
        # comparison operators, not on them being able to store booleans
        # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
        class MyArray(np.ndarray):
            def __eq__(self, other):
                return bool(np.equal(self, other).all())

            def __ne__(self, other):
                return not self == other

        a = np.array([1., 2.]).view(MyArray)
        b = np.array([2., 3.]).view(MyArray)
        # The overridden __eq__ must yield a plain Python bool.  (Previously
        # written as ``assert_(type(a == a), bool)``, which passed ``bool``
        # as the failure message and so could never fail.)
        assert_(type(a == a) is bool)
        assert_(a == a)
        assert_(a != b)
        self._test_equal(a, a)
        self._test_not_equal(a, b)
        self._test_not_equal(b, a)

        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Max absolute difference among violations: 1.\n'
                        'Max relative difference among violations: 0.5')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._test_equal(a, b)

        c = np.array([0., 2.9]).view(MyArray)
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Max absolute difference among violations: 2.\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._test_equal(b, c)

    def test_subclass_that_does_not_implement_npall(self):
        class MyArray(np.ndarray):
            def __array_function__(self, *args, **kwargs):
                return NotImplemented

        a = np.array([1., 2.]).view(MyArray)
        b = np.array([2., 3.]).view(MyArray)
        with assert_raises(TypeError):
            np.all(a)
        self._test_equal(a, a)
        self._test_not_equal(a, b)
        self._test_not_equal(b, a)

    def test_suppress_overflow_warnings(self):
        # Based on issue #18992
        with pytest.raises(AssertionError):
            with np.errstate(all="raise"):
                np.testing.assert_array_equal(
                    np.array([1, 2, 3], np.float32),
                    np.array([1, 1e-40, 3], np.float32))

    def test_array_vs_scalar_is_equal(self):
        """Test comparing an array with a scalar when all values are equal."""
        a = np.array([1., 1., 1.])
        b = 1.

        self._test_equal(a, b)

    def test_array_vs_array_not_equal(self):
        """Test comparing two arrays when not all values are equal."""
        a = np.array([34986, 545676, 439655, 563766])
        b = np.array([34986, 545676, 439655, 0])

        expected_msg = ('Mismatched elements: 1 / 4 (25%)\n'
                        'Mismatch at index:\n'
                        ' [3]: 563766 (ACTUAL), 0 (DESIRED)\n'
                        'Max absolute difference among violations: 563766\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b)

        a = np.array([34986, 545676, 439655.2, 563766])
        expected_msg = ('Mismatched elements: 2 / 4 (50%)\n'
                        'Mismatch at indices:\n'
                        ' [2]: 439655.2 (ACTUAL), 439655 (DESIRED)\n'
                        ' [3]: 563766.0 (ACTUAL), 0 (DESIRED)\n'
                        'Max absolute difference among violations: '
                        '563766.\n'
                        'Max relative difference among violations: '
                        '4.54902139e-07')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b)

    def test_array_vs_scalar_strict(self):
        """Test comparing an array with a scalar with strict option."""
        a = np.array([1., 1., 1.])
        b = 1.

        with pytest.raises(AssertionError):
            self._assert_func(a, b, strict=True)

    def test_array_vs_array_strict(self):
        """Test comparing two arrays with strict option."""
        a = np.array([1., 1., 1.])
        b = np.array([1., 1., 1.])

        self._assert_func(a, b, strict=True)

    def test_array_vs_float_array_strict(self):
        """Test that strict mode rejects equal values of different dtypes."""
        a = np.array([1, 1, 1])
        b = np.array([1., 1., 1.])

        with pytest.raises(AssertionError):
            self._assert_func(a, b, strict=True)
class TestBuildErrorMessage:
    """Check the exact text produced by ``build_err_msg``."""

    @staticmethod
    def _sample_arrays():
        # Two nearly-equal arrays shared by the tests below.
        x = np.array([1.00001, 2.00002, 3.00003])
        y = np.array([1.00002, 2.00003, 3.00004])
        return x, y

    def test_build_err_msg_defaults(self):
        x, y = self._sample_arrays()
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg)
        b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
             '1.00001, 2.00002, 3.00003])\n DESIRED: array([1.00002, '
             '2.00003, 3.00004])')
        assert_equal(a, b)

    def test_build_err_msg_no_verbose(self):
        x, y = self._sample_arrays()
        err_msg = 'There is a mismatch'

        # With verbose=False only the header line is produced.
        a = build_err_msg([x, y], err_msg, verbose=False)
        b = '\nItems are not equal: There is a mismatch'
        assert_equal(a, b)

    def test_build_err_msg_custom_names(self):
        x, y = self._sample_arrays()
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
        b = ('\nItems are not equal: There is a mismatch\n FOO: array(['
             '1.00001, 2.00002, 3.00003])\n BAR: array([1.00002, 2.00003, '
             '3.00004])')
        assert_equal(a, b)

    def test_build_err_msg_custom_precision(self):
        x = np.array([1.000000001, 2.00002, 3.00003])
        y = np.array([1.000000002, 2.00003, 3.00004])
        err_msg = 'There is a mismatch'

        a = build_err_msg([x, y], err_msg, precision=10)
        # numpy aligns the decimal points in an array repr, so the shorter
        # fractions are right-padded with spaces to the widest one
        # (9 digits here: '2.00002' + 4 spaces).
        b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
             '1.000000001, 2.00002    , 3.00003    ])\n DESIRED: array(['
             '1.000000002, 2.00003    , 3.00004    ])')
        assert_equal(a, b)
class TestEqual(TestArrayEqual):
    """Re-run the array-equality suite through ``assert_equal`` and add
    checks for scalars and special values (NaN, inf, NaT, complex, -0.0)."""

    def _assert_func(self, *args, **kwargs):
        assert_equal(*args, **kwargs)

    def test_nan_items(self):
        # NaN matches NaN only at the same nesting level.
        for actual, desired in ((np.nan, np.nan), ([np.nan], [np.nan])):
            self._assert_func(actual, desired)
        for actual, desired in ((np.nan, [np.nan]), (np.nan, 1)):
            self._test_not_equal(actual, desired)

    def test_inf_items(self):
        # inf matches inf only at the same nesting level.
        for actual, desired in ((np.inf, np.inf), ([np.inf], [np.inf])):
            self._assert_func(actual, desired)
        self._test_not_equal(np.inf, [np.inf])

    def test_datetime(self):
        jan1_s = np.datetime64("2017-01-01", "s")
        # The same instant compares equal whatever the unit.
        for other in (np.datetime64("2017-01-01", "s"),
                      np.datetime64("2017-01-01", "m")):
            self._test_equal(jan1_s, other)
        # A different day is detected whatever the unit (gh-10081).
        for other in (np.datetime64("2017-01-02", "s"),
                      np.datetime64("2017-01-02", "m")):
            self._test_not_equal(jan1_s, other)

    def test_nat_items(self):
        # NaT ("not a time") values of datetime64 kind: no unit, s, ns.
        nat_datetimes = [np.datetime64("NaT"),
                         np.datetime64("NaT", "s"),
                         np.datetime64("NaT", "ns")]
        # NaT values of timedelta64 kind: no unit, s, ns.
        nat_timedeltas = [np.timedelta64("NaT"),
                          np.timedelta64("NaT", "s"),
                          np.timedelta64("NaT", "ns")]

        # Within one kind, every NaT pair is equal (scalar and wrapped),
        # but a wrapped NaT is not equal to a bare one.
        for group in (nat_datetimes, nat_timedeltas):
            for lhs, rhs in itertools.product(group, repeat=2):
                self._assert_func(lhs, rhs)
                self._assert_func([lhs], [rhs])
                self._test_not_equal([lhs], rhs)

        # Across kinds, and against real datetime/timedelta values,
        # nothing compares equal.
        for td, dt in itertools.product(nat_timedeltas, nat_datetimes):
            self._test_not_equal(td, dt)
            self._test_not_equal(td, [dt])
            self._test_not_equal([td], [dt])
            self._test_not_equal([td], np.datetime64("2017-01-01", "s"))
            self._test_not_equal([dt], np.datetime64("2017-01-01", "s"))
            self._test_not_equal([td], np.timedelta64(123, "s"))
            self._test_not_equal([dt], np.timedelta64(123, "s"))

    def test_non_numeric(self):
        # Plain strings are compared for exact equality.
        self._assert_func('ab', 'ab')
        self._test_not_equal('ab', 'abb')

    def test_complex_item(self):
        # Matching (real, imag) parts are equal, NaN parts included.
        for parts in ((1, 2), (1, np.nan)):
            self._assert_func(complex(*parts), complex(*parts))
        # NaN must appear in the same component on both sides.
        for lhs, rhs in (((1, np.nan), (1, 2)),
                         ((np.nan, 1), (1, np.nan)),
                         ((np.nan, np.inf), (np.nan, 2))):
            self._test_not_equal(complex(*lhs), complex(*rhs))

    def test_negative_zero(self):
        # assert_equal is expected to tell +0.0 and -0.0 apart.
        self._test_not_equal(ncu.PZERO, ncu.NZERO)

    def test_complex(self):
        with_nan = np.array([complex(1, 2), complex(1, np.nan)])
        finite = np.array([complex(1, 2), complex(1, 2)])
        self._assert_func(with_nan, with_nan)
        self._test_not_equal(with_nan, finite)

    def test_object(self):
        # gh-12942: object arrays of datetimes must compare element-wise.
        import datetime
        dates = np.array([datetime.datetime(2000, 1, 1),
                          datetime.datetime(2000, 1, 2)])
        self._test_not_equal(dates, dates[::-1])
class TestArrayAlmostEqual(_GenericTest):
    """Tests for ``assert_array_almost_equal``."""

    def _assert_func(self, *args, **kwargs):
        assert_array_almost_equal(*args, **kwargs)

    def test_closeness(self):
        # Note that in the course of time we ended up with
        #     `abs(x - y) < 1.5 * 10**(-decimal)`
        # instead of the previously documented
        #     `abs(x - y) < 0.5 * 10**(-decimal)`
        # so this check serves to preserve the wrongness.

        # test scalars
        expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
                        'Max absolute difference among violations: 1.5\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(1.5, 0.0, decimal=0)

        # test arrays
        self._assert_func([1.499999], [0.0], decimal=0)

        expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
                        'Mismatch at index:\n'
                        ' [0]: 1.5 (ACTUAL), 0.0 (DESIRED)\n'
                        'Max absolute difference among violations: 1.5\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func([1.5], [0.0], decimal=0)

        a = [1.4999999, 0.00003]
        b = [1.49999991, 0]
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Mismatch at index:\n'
                        ' [1]: 3e-05 (ACTUAL), 0.0 (DESIRED)\n'
                        'Max absolute difference among violations: 3.e-05\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b, decimal=7)

        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Mismatch at index:\n'
                        ' [1]: 0.0 (ACTUAL), 3e-05 (DESIRED)\n'
                        'Max absolute difference among violations: 3.e-05\n'
                        'Max relative difference among violations: 1.')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(b, a, decimal=7)

    def test_simple(self):
        x = np.array([1234.2222])
        y = np.array([1234.2223])

        self._assert_func(x, y, decimal=3)
        self._assert_func(x, y, decimal=4)

        expected_msg = ('Mismatched elements: 1 / 1 (100%)\n'
                        'Mismatch at index:\n'
                        ' [0]: 1234.2222 (ACTUAL), 1234.2223 (DESIRED)\n'
                        'Max absolute difference among violations: '
                        '1.e-04\n'
                        'Max relative difference among violations: '
                        '8.10226812e-08')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(x, y, decimal=5)

    def test_array_vs_scalar(self):
        a = [5498.42354, 849.54345, 0.00]
        b = 5498.42354
        expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
                        'Mismatch at indices:\n'
                        ' [1]: 849.54345 (ACTUAL), 5498.42354 (DESIRED)\n'
                        ' [2]: 0.0 (ACTUAL), 5498.42354 (DESIRED)\n'
                        'Max absolute difference among violations: '
                        '5498.42354\n'
                        'Max relative difference among violations: 1.')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b, decimal=9)

        expected_msg = ('Mismatched elements: 2 / 3 (66.7%)\n'
                        'Mismatch at indices:\n'
                        ' [1]: 5498.42354 (ACTUAL), 849.54345 (DESIRED)\n'
                        ' [2]: 5498.42354 (ACTUAL), 0.0 (DESIRED)\n'
                        'Max absolute difference among violations: '
                        '5498.42354\n'
                        'Max relative difference among violations: 5.4722099')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(b, a, decimal=9)

        a = [5498.42354, 0.00]
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Mismatch at index:\n'
                        ' [1]: 5498.42354 (ACTUAL), 0.0 (DESIRED)\n'
                        'Max absolute difference among violations: '
                        '5498.42354\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(b, a, decimal=7)

        b = 0
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Mismatch at index:\n'
                        ' [0]: 5498.42354 (ACTUAL), 0 (DESIRED)\n'
                        'Max absolute difference among violations: '
                        '5498.42354\n'
                        'Max relative difference among violations: inf')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b, decimal=7)

    def test_nan(self):
        anan = np.array([np.nan])
        aone = np.array([1])
        ainf = np.array([np.inf])
        self._assert_func(anan, anan)
        # NaN is close only to NaN, never to a finite value or to inf.
        for actual, desired in ((anan, aone), (anan, ainf), (ainf, anan)):
            with assert_raises(AssertionError):
                self._assert_func(actual, desired)

    def test_inf(self):
        a = np.array([[1., 2.], [3., 4.]])
        b = a.copy()
        a[0, 0] = np.inf
        with assert_raises(AssertionError):
            self._assert_func(a, b)
        # Infinities of opposite sign are not close either.
        b[0, 0] = -np.inf
        with assert_raises(AssertionError):
            self._assert_func(a, b)

    def test_complex_inf(self):
        a = np.array([np.inf + 1.j, 2. + 1.j, 3. + 1.j])
        b = a.copy()
        self._assert_func(a, b)
        b[1] = 3. + 1.j
        expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
                        'Mismatch at index:\n'
                        ' [1]: (2+1j) (ACTUAL), (3+1j) (DESIRED)\n'
                        'Max absolute difference among violations: 1.\n')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b)

    def test_subclass(self):
        a = np.array([[1., 2.], [3., 4.]])
        b = np.ma.masked_array([[1., 2.], [0., 4.]],
                               [[False, False], [True, False]])
        self._assert_func(a, b)
        self._assert_func(b, a)
        self._assert_func(b, b)

        # Test fully masked as well (see gh-11123).
        a = np.ma.MaskedArray(3.5, mask=True)
        b = np.array([3., 4., 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.masked
        b = np.array([3., 4., 6.5])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
        b = np.array([1., 2., 3.])
        self._test_equal(a, b)
        self._test_equal(b, a)
        a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
        b = np.array(1.)
        self._test_equal(a, b)
        self._test_equal(b, a)

    def test_subclass_2(self):
        # While we cannot guarantee testing functions will always work for
        # subclasses, the tests should ideally rely only on subclasses having
        # comparison operators, not on them being able to store booleans
        # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
        class MyArray(np.ndarray):
            def __eq__(self, other):
                return super().__eq__(other).view(np.ndarray)

            def __lt__(self, other):
                return super().__lt__(other).view(np.ndarray)

            def all(self, *args, **kwargs):
                # Delegates to the builtin all(), not ndarray.all.
                return all(self)

        a = np.array([1., 2.]).view(MyArray)
        self._assert_func(a, a)

        # Sanity check that the builtin all() used by MyArray.all works on
        # the subclass.  (This was a bare ``all(z)`` whose result was
        # discarded, so it never checked anything.)
        z = np.array([True, True]).view(MyArray)
        assert_(all(z))

        b = np.array([1., 202]).view(MyArray)
        expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
                        'Mismatch at index:\n'
                        ' [1]: 2.0 (ACTUAL), 202.0 (DESIRED)\n'
                        'Max absolute difference among violations: 200.\n'
                        'Max relative difference among violations: 0.99009')
        with pytest.raises(AssertionError, match=re.escape(expected_msg)):
            self._assert_func(a, b)

    def test_subclass_that_cannot_be_bool(self):
        # While we cannot guarantee testing functions will always work for
        # subclasses, the tests should ideally rely only on subclasses having
        # comparison operators, not on them being able to store booleans
        # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
        class MyArray(np.ndarray):
            def __eq__(self, other):
                return super().__eq__(other).view(np.ndarray)

            def __lt__(self, other):
                return super().__lt__(other).view(np.ndarray)

            def all(self, *args, **kwargs):
                raise NotImplementedError

        # Comparing must succeed without ever calling MyArray.all.
        a = np.array([1., 2.]).view(MyArray)
        self._assert_func(a, a)
- class TestAlmostEqual(_GenericTest):
- def _assert_func(self, *args, **kwargs):
- assert_almost_equal(*args, **kwargs)
- def test_closeness(self):
- # Note that in the course of time we ended up with
- # `abs(x - y) < 1.5 * 10**(-decimal)`
- # instead of the previously documented
- # `abs(x - y) < 0.5 * 10**(-decimal)`
- # so this check serves to preserve the wrongness.
- # test scalars
- self._assert_func(1.499999, 0.0, decimal=0)
- assert_raises(AssertionError,
- lambda: self._assert_func(1.5, 0.0, decimal=0))
- # test arrays
- self._assert_func([1.499999], [0.0], decimal=0)
- assert_raises(AssertionError,
- lambda: self._assert_func([1.5], [0.0], decimal=0))
- def test_nan_item(self):
- self._assert_func(np.nan, np.nan)
- assert_raises(AssertionError,
- lambda: self._assert_func(np.nan, 1))
- assert_raises(AssertionError,
- lambda: self._assert_func(np.nan, np.inf))
- assert_raises(AssertionError,
- lambda: self._assert_func(np.inf, np.nan))
- def test_inf_item(self):
- self._assert_func(np.inf, np.inf)
- self._assert_func(-np.inf, -np.inf)
- assert_raises(AssertionError,
- lambda: self._assert_func(np.inf, 1))
- assert_raises(AssertionError,
- lambda: self._assert_func(-np.inf, np.inf))
- def test_simple_item(self):
- self._test_not_equal(1, 2)
- def test_complex_item(self):
- self._assert_func(complex(1, 2), complex(1, 2))
- self._assert_func(complex(1, np.nan), complex(1, np.nan))
- self._assert_func(complex(np.inf, np.nan), complex(np.inf, np.nan))
- self._test_not_equal(complex(1, np.nan), complex(1, 2))
- self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
- self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
- def test_complex(self):
- x = np.array([complex(1, 2), complex(1, np.nan)])
- z = np.array([complex(1, 2), complex(np.nan, 1)])
- y = np.array([complex(1, 2), complex(1, 2)])
- self._assert_func(x, x)
- self._test_not_equal(x, y)
- self._test_not_equal(x, z)
- def test_error_message(self):
- """Check the message is formatted correctly for the decimal value.
- Also check the message when input includes inf or nan (gh12200)"""
- x = np.array([1.00000000001, 2.00000000002, 3.00003])
- y = np.array([1.00000000002, 2.00000000003, 3.00004])
- # Test with a different amount of decimal digits
- expected_msg = ('Mismatched elements: 3 / 3 (100%)\n'
- 'Mismatch at indices:\n'
- ' [0]: 1.00000000001 (ACTUAL), 1.00000000002 (DESIRED)\n'
- ' [1]: 2.00000000002 (ACTUAL), 2.00000000003 (DESIRED)\n'
- ' [2]: 3.00003 (ACTUAL), 3.00004 (DESIRED)\n'
- 'Max absolute difference among violations: 1.e-05\n'
- 'Max relative difference among violations: '
- '3.33328889e-06\n'
- ' ACTUAL: array([1.00000000001, '
- '2.00000000002, '
- '3.00003 ])\n'
- ' DESIRED: array([1.00000000002, 2.00000000003, '
- '3.00004 ])')
- with pytest.raises(AssertionError, match=re.escape(expected_msg)):
- self._assert_func(x, y, decimal=12)
- # With the default value of decimal digits, only the 3rd element
- # differs. Note that we only check for the formatting of the arrays
- # themselves.
- expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n'
- 'Mismatch at index:\n'
- ' [2]: 3.00003 (ACTUAL), 3.00004 (DESIRED)\n'
- 'Max absolute difference among violations: 1.e-05\n'
- 'Max relative difference among violations: '
- '3.33328889e-06\n'
- ' ACTUAL: array([1. , 2. , 3.00003])\n'
- ' DESIRED: array([1. , 2. , 3.00004])')
- with pytest.raises(AssertionError, match=re.escape(expected_msg)):
- self._assert_func(x, y)
- # Check the error message when input includes inf
- x = np.array([np.inf, 0])
- y = np.array([np.inf, 1])
- expected_msg = ('Mismatched elements: 1 / 2 (50%)\n'
- 'Mismatch at index:\n'
- ' [1]: 0.0 (ACTUAL), 1.0 (DESIRED)\n'
- 'Max absolute difference among violations: 1.\n'
- 'Max relative difference among violations: 1.\n'
- ' ACTUAL: array([inf, 0.])\n'
- ' DESIRED: array([inf, 1.])')
- with pytest.raises(AssertionError, match=re.escape(expected_msg)):
- self._assert_func(x, y)
- # Check the error message when dividing by zero
- x = np.array([1, 2])
- y = np.array([0, 0])
- expected_msg = ('Mismatched elements: 2 / 2 (100%)\n'
- 'Mismatch at indices:\n'
- ' [0]: 1 (ACTUAL), 0 (DESIRED)\n'
- ' [1]: 2 (ACTUAL), 0 (DESIRED)\n'
- 'Max absolute difference among violations: 2\n'
- 'Max relative difference among violations: inf')
- with pytest.raises(AssertionError, match=re.escape(expected_msg)):
- self._assert_func(x, y)
- def test_error_message_2(self):
- """Check the message is formatted correctly """
- """when either x or y is a scalar."""
- x = 2
- y = np.ones(20)
- expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
- 'First 5 mismatches are at indices:\n'
- ' [0]: 2 (ACTUAL), 1.0 (DESIRED)\n'
- ' [1]: 2 (ACTUAL), 1.0 (DESIRED)\n'
- ' [2]: 2 (ACTUAL), 1.0 (DESIRED)\n'
- ' [3]: 2 (ACTUAL), 1.0 (DESIRED)\n'
- ' [4]: 2 (ACTUAL), 1.0 (DESIRED)\n'
- 'Max absolute difference among violations: 1.\n'
- 'Max relative difference among violations: 1.')
- with pytest.raises(AssertionError, match=re.escape(expected_msg)):
- self._assert_func(x, y)
- y = 2
- x = np.ones(20)
- expected_msg = ('Mismatched elements: 20 / 20 (100%)\n'
- 'First 5 mismatches are at indices:\n'
- ' [0]: 1.0 (ACTUAL), 2 (DESIRED)\n'
- ' [1]: 1.0 (ACTUAL), 2 (DESIRED)\n'
- ' [2]: 1.0 (ACTUAL), 2 (DESIRED)\n'
- ' [3]: 1.0 (ACTUAL), 2 (DESIRED)\n'
- ' [4]: 1.0 (ACTUAL), 2 (DESIRED)\n'
- 'Max absolute difference among violations: 1.\n'
- 'Max relative difference among violations: 0.5')
- with pytest.raises(AssertionError, match=re.escape(expected_msg)):
- self._assert_func(x, y)
- def test_subclass_that_cannot_be_bool(self):
- # While we cannot guarantee testing functions will always work for
- # subclasses, the tests should ideally rely only on subclasses having
- # comparison operators, not on them being able to store booleans
- # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
- class MyArray(np.ndarray):
- def __eq__(self, other):
- return super().__eq__(other).view(np.ndarray)
- def __lt__(self, other):
- return super().__lt__(other).view(np.ndarray)
- def all(self, *args, **kwargs):
- raise NotImplementedError
- a = np.array([1., 2.]).view(MyArray)
- self._assert_func(a, a)
class TestApproxEqual:
    """Tests for ``assert_approx_equal`` (comparison to N significant digits)."""

    def _assert_func(self, *args, **kwargs):
        assert_approx_equal(*args, **kwargs)

    def test_simple_0d_arrays(self):
        x = np.array(1234.22)
        y = np.array(1234.23)
        self._assert_func(x, y, significant=5)
        self._assert_func(x, y, significant=6)
        assert_raises(AssertionError,
                      lambda: self._assert_func(x, y, significant=7))

    def test_simple_items(self):
        x = 1234.22
        y = 1234.23
        self._assert_func(x, y, significant=4)
        self._assert_func(x, y, significant=5)
        self._assert_func(x, y, significant=6)
        assert_raises(AssertionError,
                      lambda: self._assert_func(x, y, significant=7))

    def test_nan_array(self):
        anan = np.array(np.nan)
        aone = np.array(1)
        ainf = np.array(np.inf)
        self._assert_func(anan, anan)
        assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
        assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
        assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))

    def test_nan_items(self):
        # Same checks as test_nan_array, but with plain Python scalars
        # rather than 0-d arrays, so the "items" input path is exercised.
        anan = float('nan')
        aone = 1
        ainf = float('inf')
        self._assert_func(anan, anan)
        assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
        assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
        assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
class TestArrayAssertLess:
    """Tests for ``assert_array_less`` (element-wise strict ``x < y``).

    Failure reports label the operands ``(x)`` and ``(y)`` and count equal
    elements as violations.
    """

    def _assert_func(self, *args, **kwargs):
        assert_array_less(*args, **kwargs)

    def _assert_fails(self, *args, **kwargs):
        # The comparison must raise AssertionError; the report is not checked.
        assert_raises(AssertionError, self._assert_func, *args, **kwargs)

    def _assert_fails_with(self, report, *args, **kwargs):
        # The comparison must raise AssertionError whose message contains
        # ``report`` verbatim.
        with pytest.raises(AssertionError, match=re.escape(report)):
            self._assert_func(*args, **kwargs)

    def test_simple_arrays(self):
        lower = np.array([1.1, 2.2])
        upper = np.array([1.2, 2.3])
        self._assert_func(lower, upper)
        self._assert_fails(upper, lower)

        # The two entries disagree in ordering, so both directions fail.
        mixed = np.array([1.0, 2.3])
        self._assert_fails(lower, mixed)
        self._assert_fails(mixed, lower)

        # Equal elements (index 2) count as violations too.
        report = ('Mismatched elements: 2 / 4 (50%)\n'
                  'Mismatch at indices:\n'
                  ' [2]: 6 (x), 6 (y)\n'
                  ' [3]: 20 (x), 8 (y)\n'
                  'Max absolute difference among violations: 12\n'
                  'Max relative difference among violations: 1.5')
        self._assert_fails_with(report,
                                np.array([1, 3, 6, 20]),
                                np.array([2, 4, 6, 8]))

    def test_rank2(self):
        lower = np.array([[1.1, 2.2], [3.3, 4.4]])
        upper = np.array([[1.2, 2.3], [3.4, 4.5]])
        self._assert_func(lower, upper)

        report = ('Mismatched elements: 4 / 4 (100%)\n'
                  'Mismatch at indices:\n'
                  ' [0, 0]: 1.2 (x), 1.1 (y)\n'
                  ' [0, 1]: 2.3 (x), 2.2 (y)\n'
                  ' [1, 0]: 3.4 (x), 3.3 (y)\n'
                  ' [1, 1]: 4.5 (x), 4.4 (y)\n'
                  'Max absolute difference among violations: 0.1\n'
                  'Max relative difference among violations: 0.09090909')
        self._assert_fails_with(report, upper, lower)

        mixed = np.array([[1.0, 2.3], [3.4, 4.5]])
        self._assert_fails(lower, mixed)
        self._assert_fails(mixed, lower)

    def test_rank3(self):
        base = np.ones(shape=(2, 2, 2))
        bound = np.ones(shape=(2, 2, 2)) + 1
        self._assert_func(base, bound)
        self._assert_fails(bound, base)

        # Drop a single element of the bound below the base.
        bound[0, 0, 0] = 0
        report = ('Mismatched elements: 1 / 8 (12.5%)\n'
                  'Mismatch at index:\n'
                  ' [0, 0, 0]: 1.0 (x), 0.0 (y)\n'
                  'Max absolute difference among violations: 1.\n'
                  'Max relative difference among violations: inf')
        self._assert_fails_with(report, base, bound)
        self._assert_fails(bound, base)

    def test_simple_items(self):
        scalar = 1.1
        bound = 2.2
        self._assert_func(scalar, bound)
        report = ('Mismatched elements: 1 / 1 (100%)\n'
                  'Max absolute difference among violations: 1.1\n'
                  'Max relative difference among violations: 1.')
        self._assert_fails_with(report, bound, scalar)

        # Scalar against an array.
        bound = np.array([2.2, 3.3])
        self._assert_func(scalar, bound)
        self._assert_fails(bound, scalar)

        bound = np.array([1.0, 3.3])
        self._assert_fails(scalar, bound)

    def test_simple_items_and_array(self):
        values = np.array([[621.345454, 390.5436, 43.54657, 626.4535],
                           [54.54, 627.3399, 13., 405.5435],
                           [543.545, 8.34, 91.543, 333.3]])
        # Scalars just above the maximum and just below the minimum.
        self._assert_func(values, 627.34)
        self._assert_func(8.339999, values)

        values = np.array([[3.4536, 2390.5436, 435.54657, 324525.4535],
                           [5449.54, 999090.54, 130303.54, 405.5435],
                           [543.545, 8.34, 91.543, 999090.53999]])
        scalar = 999090.54

        # Only the element equal to the scalar violates values < scalar.
        report = ('Mismatched elements: 1 / 12 (8.33%)\n'
                  'Mismatch at index:\n'
                  ' [1, 1]: 999090.54 (x), 999090.54 (y)\n'
                  'Max absolute difference among violations: 0.\n'
                  'Max relative difference among violations: 0.')
        self._assert_fails_with(report, values, scalar)

        report = ('Mismatched elements: 12 / 12 (100%)\n'
                  'First 5 mismatches are at indices:\n'
                  ' [0, 0]: 999090.54 (x), 3.4536 (y)\n'
                  ' [0, 1]: 999090.54 (x), 2390.5436 (y)\n'
                  ' [0, 2]: 999090.54 (x), 435.54657 (y)\n'
                  ' [0, 3]: 999090.54 (x), 324525.4535 (y)\n'
                  ' [1, 0]: 999090.54 (x), 5449.54 (y)\n'
                  'Max absolute difference among violations: '
                  '999087.0864\n'
                  'Max relative difference among violations: '
                  '289288.5934676')
        self._assert_fails_with(report, scalar, values)

    def test_zeroes(self):
        values = np.array([546456., 0, 15.455])
        bound = np.array(87654.)

        report = ('Mismatched elements: 1 / 3 (33.3%)\n'
                  'Mismatch at index:\n'
                  ' [0]: 546456.0 (x), 87654.0 (y)\n'
                  'Max absolute difference among violations: 458802.\n'
                  'Max relative difference among violations: 5.23423917')
        self._assert_fails_with(report, values, bound)

        report = ('Mismatched elements: 2 / 3 (66.7%)\n'
                  'Mismatch at indices:\n'
                  ' [1]: 87654.0 (x), 0.0 (y)\n'
                  ' [2]: 87654.0 (x), 15.455 (y)\n'
                  'Max absolute difference among violations: 87654.\n'
                  'Max relative difference among violations: '
                  '5670.5626011')
        self._assert_fails_with(report, bound, values)

        # Against a scalar zero.
        bound = 0
        report = ('Mismatched elements: 3 / 3 (100%)\n'
                  'Mismatch at indices:\n'
                  ' [0]: 546456.0 (x), 0 (y)\n'
                  ' [1]: 0.0 (x), 0 (y)\n'
                  ' [2]: 15.455 (x), 0 (y)\n'
                  'Max absolute difference among violations: 546456.\n'
                  'Max relative difference among violations: inf')
        self._assert_fails_with(report, values, bound)

        report = ('Mismatched elements: 1 / 3 (33.3%)\n'
                  'Mismatch at index:\n'
                  ' [1]: 0 (x), 0.0 (y)\n'
                  'Max absolute difference among violations: 0.\n'
                  'Max relative difference among violations: inf')
        self._assert_fails_with(report, bound, values)

    def test_nan_noncompare(self):
        anan = np.array(np.nan)
        aone = np.array(1)
        ainf = np.array(np.inf)
        # NaN against NaN is accepted; NaN mixed with a number is not.
        self._assert_func(anan, anan)
        for lhs, rhs in [(aone, anan), (anan, aone),
                         (anan, ainf), (ainf, anan)]:
            self._assert_fails(lhs, rhs)

    def test_nan_noncompare_array(self):
        anan = np.array(np.nan)
        finite = np.array([1.1, 2.2, 3.3])
        with_nan = np.array([1.1, 2.2, np.nan])
        for arr in (finite, with_nan):
            self._assert_fails(arr, anan)
            self._assert_fails(anan, arr)

        # NaN positions line up and the remaining entries are ordered.
        lower = np.array([1.0, 2.0, np.nan])
        self._assert_func(lower, with_nan)
        self._assert_fails(with_nan, lower)

    def test_inf_compare(self):
        aone = np.array(1)
        ainf = np.array(np.inf)
        for lhs, rhs in [(aone, ainf), (-ainf, aone), (-ainf, ainf)]:
            self._assert_func(lhs, rhs)
        for lhs, rhs in [(ainf, aone), (aone, -ainf), (ainf, ainf),
                         (ainf, -ainf), (-ainf, -ainf)]:
            self._assert_fails(lhs, rhs)

    def test_inf_compare_array(self):
        x = np.array([1.1, 2.2, np.inf])
        ainf = np.array(np.inf)
        for lhs, rhs in [(x, ainf), (ainf, x), (x, -ainf),
                         (-x, -ainf), (-ainf, -x)]:
            self._assert_fails(lhs, rhs)
        self._assert_func(-ainf, x)

    def test_strict(self):
        """Test the behavior of the `strict` option."""
        x = np.zeros(3)
        y = np.ones(())
        # Differing shapes: accepted normally, rejected under strict.
        self._assert_func(x, y)
        self._assert_fails(x, y, strict=True)

        # Same shape but differing dtypes: rejected under strict.
        y = np.broadcast_to(y, x.shape)
        self._assert_func(x, y)
        self._assert_fails(x, y.astype(np.float32), strict=True)
@pytest.mark.filterwarnings(
    "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
    ".*:DeprecationWarning")
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
class TestWarns:
    """Tests for ``assert_warns`` and ``assert_no_warnings``.

    The class-level mark ignores the DeprecationWarning about these
    utilities themselves.
    """

    def test_warn(self):
        def f():
            warnings.warn("yo")
            return 3

        before_filters = sys.modules['warnings'].filters[:]
        # assert_warns returns the wrapped function's result.
        assert_equal(assert_warns(UserWarning, f), 3)
        after_filters = sys.modules['warnings'].filters

        assert_raises(AssertionError, assert_no_warnings, f)
        assert_equal(assert_no_warnings(lambda x: x, 1), 1)

        # Check that the warnings state is unchanged
        assert_equal(before_filters, after_filters,
                     "assert_warns does not preserve warnings state")

    def test_context_manager(self):
        before_filters = sys.modules['warnings'].filters[:]
        with assert_warns(UserWarning):
            warnings.warn("yo")
        after_filters = sys.modules['warnings'].filters

        def no_warnings():
            with assert_no_warnings():
                warnings.warn("yo")

        assert_raises(AssertionError, no_warnings)
        assert_equal(before_filters, after_filters,
                     "assert_warns does not preserve warnings state")

    def test_args(self):
        def f(a=0, b=1):
            warnings.warn("yo")
            return a + b

        # Keyword arguments are forwarded to the wrapped callable.
        assert assert_warns(UserWarning, f, b=20) == 20

        # assert_warns cannot do regexp matching; the error must point the
        # user at pytest.warns instead.
        with pytest.raises(RuntimeError) as exc:
            with assert_warns(UserWarning, match="A"):
                warnings.warn("B", UserWarning)
        # Check the exception itself: str() of the ExceptionInfo wrapper is
        # only its repr, which pytest may truncate.
        message = str(exc.value)
        assert "assert_warns" in message
        assert "pytest.warns" in message

        # Any other unexpected keyword is rejected without that hint.
        with pytest.raises(RuntimeError) as exc:
            with assert_warns(UserWarning, wrong="A"):
                warnings.warn("B", UserWarning)
        message = str(exc.value)
        assert "assert_warns" in message
        assert "pytest.warns" not in message

    def test_warn_wrong_warning(self):
        def f():
            warnings.warn("yo", DeprecationWarning)

        with warnings.catch_warnings():
            warnings.simplefilter("error", DeprecationWarning)
            # The "error" filter above takes precedence over the class-level
            # ignore. Without this, a deprecation warning emitted by
            # assert_warns itself would be raised before f runs and the
            # test would pass vacuously.
            warnings.filterwarnings(
                "ignore",
                message=(".*NumPy warning suppression and assertion "
                         "utilities are deprecated"),
                category=DeprecationWarning)
            # The DeprecationWarning from f must propagate out of
            # assert_warns, which only expects a UserWarning.
            with pytest.raises(DeprecationWarning, match="yo"):
                assert_warns(UserWarning, f)
class TestAssertAllclose:
    """Tests for ``assert_allclose`` and the content of its failure report."""

    def _assert_fails_with(self, report, *args, **kwargs):
        # assert_allclose must fail with ``report`` somewhere in its message.
        with pytest.raises(AssertionError, match=re.escape(report)):
            assert_allclose(*args, **kwargs)

    def test_simple(self):
        x = 1e-3
        y = 1e-9
        assert_allclose(x, y, atol=1)
        assert_raises(AssertionError, assert_allclose, x, y)
        report = ('Mismatched elements: 1 / 1 (100%)\n'
                  'Max absolute difference among violations: 0.001\n'
                  'Max relative difference among violations: 999999.')
        self._assert_fails_with(report, x, y)

        # Relative difference against a zero desired value is inf; against
        # a zero actual value it is 1.
        zero = 0
        report = ('Mismatched elements: 1 / 1 (100%)\n'
                  'Max absolute difference among violations: 1.e-09\n'
                  'Max relative difference among violations: inf')
        self._assert_fails_with(report, y, zero)
        report = ('Mismatched elements: 1 / 1 (100%)\n'
                  'Max absolute difference among violations: 1.e-09\n'
                  'Max relative difference among violations: 1.')
        self._assert_fails_with(report, zero, y)

        a = np.array([x, y, x, y])
        b = np.array([x, y, x, x])
        assert_allclose(a, b, atol=1)
        assert_raises(AssertionError, assert_allclose, a, b)
        # A 1e-8 relative perturbation passes by default but not at 1e-9.
        b[-1] = y * (1 + 1e-8)
        assert_allclose(a, b)
        assert_raises(AssertionError, assert_allclose, a, b, rtol=1e-9)

        # rtol is measured against the second (desired) argument.
        assert_allclose(6, 10, rtol=0.5)
        assert_raises(AssertionError, assert_allclose, 10, 6, rtol=0.5)

        b = np.array([x, y, x, x])
        c = np.array([x, y, x, zero])
        report = ('Mismatched elements: 1 / 4 (25%)\n'
                  'Mismatch at index:\n'
                  ' [3]: 0.001 (ACTUAL), 0.0 (DESIRED)\n'
                  'Max absolute difference among violations: 0.001\n'
                  'Max relative difference among violations: inf')
        self._assert_fails_with(report, b, c)
        report = ('Mismatched elements: 1 / 4 (25%)\n'
                  'Mismatch at index:\n'
                  ' [3]: 0.0 (ACTUAL), 0.001 (DESIRED)\n'
                  'Max absolute difference among violations: 0.001\n'
                  'Max relative difference among violations: 1.')
        self._assert_fails_with(report, c, b)

    def test_min_int(self):
        # The most negative np.int_ value must compare close to itself.
        smallest = np.array([np.iinfo(np.int_).min], dtype=np.int_)
        assert_allclose(smallest, smallest)

    def test_report_fail_percentage(self):
        report = ('Mismatched elements: 1 / 4 (25%)\n'
                  'Mismatch at index:\n'
                  ' [3]: 1 (ACTUAL), 2 (DESIRED)\n'
                  'Max absolute difference among violations: 1\n'
                  'Max relative difference among violations: 0.5')
        self._assert_fails_with(report,
                                np.array([1, 1, 1, 1]),
                                np.array([1, 1, 1, 2]))

    def test_equal_nan(self):
        # With equal_nan=True, NaNs in matching positions compare equal,
        # including complex values with a NaN real part.
        pairs = [
            (np.array([np.nan]), np.array([np.nan])),
            (np.array([complex(np.nan, np.inf)]),
             np.array([complex(np.nan, np.inf)])),
            (np.array([complex(np.nan, np.inf)]),
             np.array([complex(np.nan, -np.inf)])),
        ]
        for actual, desired in pairs:
            assert_allclose(actual, desired, equal_nan=True)

    def test_not_equal_nan(self):
        pairs = [
            (np.array([np.nan]), np.array([np.nan])),
            (np.array([complex(np.nan, np.inf)]),
             np.array([complex(np.nan, np.inf)])),
        ]
        for actual, desired in pairs:
            assert_raises(AssertionError, assert_allclose, actual, desired,
                          equal_nan=False)

    def test_equal_nan_default(self):
        # The default equal_nan behaviour must stay unchanged: none of these
        # (all built on assert_array_compare) may raise for NaN vs NaN.
        a = np.array([np.nan])
        b = np.array([np.nan])
        for check in (assert_array_equal, assert_array_almost_equal,
                      assert_array_less, assert_allclose):
            check(a, b)

    def test_report_max_relative_error(self):
        self._assert_fails_with(
            'Max relative difference among violations: 0.5',
            np.array([0, 1]), np.array([0, 2]))

    def test_timedelta(self):
        # see gh-18286: a timedelta array containing NaT is close to itself.
        deltas = np.array([[1, 2, 3, "NaT"]], dtype="m8[ns]")
        assert_allclose(deltas, deltas)

    def test_error_message_unsigned(self):
        """Check the message is formatted correctly when overflow can occur
        (gh21768)"""
        # With uint8, x - y would wrap for the first two elements and
        # y - x for the last one; the reported difference must still be 4.
        x = np.asarray([0, 1, 8], dtype='uint8')
        y = np.asarray([4, 4, 4], dtype='uint8')
        self._assert_fails_with(
            'Max absolute difference among violations: 4', x, y, atol=3)

    def test_strict(self):
        """Test the behavior of the `strict` option."""
        x = np.ones(3)
        zero_dim = np.ones(())
        # Differing shapes: accepted normally, rejected under strict.
        assert_allclose(x, zero_dim)
        assert_raises(AssertionError, assert_allclose, x, zero_dim,
                      strict=True)
        # Same shape but differing dtypes: rejected under strict.
        assert_allclose(x, x)
        assert_raises(AssertionError, assert_allclose, x,
                      x.astype(np.float32), strict=True)

    def test_infs(self):
        # (actual, desired, expected report or None when the call passes)
        real_inf = np.array([np.inf])
        complex_inf = np.array([complex(np.inf, 1.)])
        cases = [
            (real_inf, np.array([np.inf]), None),
            (real_inf, np.array([3.]), 'inf location mismatch:'),
            (real_inf, np.array([-np.inf]), 'inf values mismatch:'),
            (real_inf, np.array([complex(np.inf, 1.)]),
             'inf values mismatch:'),
            (complex_inf, np.array([complex(np.inf, 1.)]), None),
            (complex_inf, np.array([complex(np.inf, 2.)]),
             'inf values mismatch:'),
        ]
        for actual, desired, report in cases:
            if report is None:
                assert_allclose(actual, desired)
            else:
                self._assert_fails_with(report, actual, desired)
class TestArrayAlmostEqualNulp:
    """Tests for ``assert_array_almost_equal_nulp`` and for NaN payloads in
    ``assert_array_max_ulp``.

    The ``*_pass`` / ``*_fail`` tests take a symmetric grid of signed powers
    of ten, perturb it by a relative amount proportional to the dtype's
    ``eps`` (upwards) or ``epsneg`` (downwards), and check that the result
    is accepted / rejected for ``nulp`` units in the last place.  The
    shared logic lives in the private helpers.
    """

    @staticmethod
    def _signed_powers_of_ten(dtype, start, stop, num):
        # -10**e followed by +10**e for ``num`` exponents evenly spaced in
        # [start, stop], all computed in ``dtype``.
        exponents = np.linspace(start, stop, num, dtype=dtype)
        x = 10**exponents
        return np.r_[-x, x]

    @staticmethod
    def _check_nulp(expected, actual, nulp, within):
        # Assert ``actual`` is (within=True) or is not (within=False) within
        # ``nulp`` units in the last place of ``expected``.
        if within:
            assert_array_almost_equal_nulp(expected, actual, nulp)
        else:
            assert_raises(AssertionError, assert_array_almost_equal_nulp,
                          expected, actual, nulp)

    def _check_real(self, dtype, start, stop, num, within, nulp=5):
        # nulp=5 is a few places above the lowest level (i.e. nulp=1).
        # Accepted perturbations use nulp/2, rejected ones use nulp*2.
        x = self._signed_powers_of_ten(dtype, start, stop, num)
        eps = np.finfo(x.dtype).eps
        epsneg = np.finfo(x.dtype).epsneg
        if within:
            perturbed = (x + x * eps * nulp / 2.,
                         x - x * epsneg * nulp / 2.)
        else:
            perturbed = (x + x * eps * nulp * 2.,
                         x - x * epsneg * nulp * 2.)
        for y in perturbed:
            self._check_nulp(x, y, nulp, within)

    def _check_complex(self, dtype, within, nulp=5):
        # ``dtype`` is the real component dtype (float64 -> complex128,
        # float32 -> complex64).  Each pair is (y used when only one of the
        # real/imaginary parts changes, y used when both change); when both
        # change the perturbation must be at least a factor of sqrt(2)
        # smaller, hence nulp/4 vs nulp/2 on the accepting side.
        x = self._signed_powers_of_ten(dtype, -20, 20, 50)
        xi = x + x * 1j
        eps = np.finfo(x.dtype).eps
        epsneg = np.finfo(x.dtype).epsneg
        if within:
            perturbed = [
                (x + x * eps * nulp / 2., x + x * eps * nulp / 4.),
                (x - x * epsneg * nulp / 2., x - x * epsneg * nulp / 4.),
            ]
        else:
            perturbed = [
                (x + x * eps * nulp * 2., x + x * eps * nulp),
                (x - x * epsneg * nulp * 2., x - x * epsneg * nulp),
            ]
        for y_one, y_both in perturbed:
            self._check_nulp(xi, x + y_one * 1j, nulp, within)
            self._check_nulp(xi, y_one + x * 1j, nulp, within)
            self._check_nulp(xi, y_both + y_both * 1j, nulp, within)

    @staticmethod
    def _check_nan_payload(float_dtype, uint_dtype, mask):
        # Ignore ULP differences between various NaNs.  MIPS may reverse
        # quiet and signaling NaNs, so the builtin NaN is the base and a
        # second NaN is made by flipping payload bits with ``mask``.
        offset = uint_dtype(mask)
        nan1_bits = np.array(np.nan, dtype=float_dtype).view(uint_dtype)
        nan2_bits = nan1_bits ^ offset  # nan payload on MIPS is all ones.
        assert_array_max_ulp(nan1_bits.view(float_dtype),
                             nan2_bits.view(float_dtype), 0)

    def test_float64_pass(self):
        self._check_real(np.float64, -20, 20, 50, within=True)

    def test_float64_fail(self):
        self._check_real(np.float64, -20, 20, 50, within=False)

    def test_float64_ignore_nan(self):
        self._check_nan_payload(np.float64, np.uint64, 0xffffffff)

    def test_float32_pass(self):
        self._check_real(np.float32, -20, 20, 50, within=True)

    def test_float32_fail(self):
        self._check_real(np.float32, -20, 20, 50, within=False)

    def test_float32_ignore_nan(self):
        self._check_nan_payload(np.float32, np.uint32, 0xffff)

    def test_float16_pass(self):
        self._check_real(np.float16, -4, 4, 10, within=True)

    def test_float16_fail(self):
        self._check_real(np.float16, -4, 4, 10, within=False)

    def test_float16_ignore_nan(self):
        self._check_nan_payload(np.float16, np.uint16, 0xff)

    def test_complex128_pass(self):
        self._check_complex(np.float64, within=True)

    def test_complex128_fail(self):
        self._check_complex(np.float64, within=False)

    def test_complex64_pass(self):
        self._check_complex(np.float32, within=True)

    def test_complex64_fail(self):
        self._check_complex(np.float32, within=False)
class TestULP:
    """Tests for ``assert_array_max_ulp``.

    Random inputs come from a fixed-seed generator so that any failure is
    reproducible (the global, unseeded ``np.random`` state is not used).
    """

    def test_equal(self):
        rng = np.random.default_rng(1234)
        x = rng.standard_normal(10)
        assert_array_max_ulp(x, x, maxulp=0)

    def test_single(self):
        # Generate 1 + small deviation, check that adding eps gives a few ULP
        rng = np.random.default_rng(1234)
        x = np.ones(10).astype(np.float32)
        x += 0.01 * rng.standard_normal(10).astype(np.float32)
        eps = np.finfo(np.float32).eps
        assert_array_max_ulp(x, x + eps, maxulp=20)

    def test_double(self):
        # Generate 1 + small deviation, check that adding eps gives a few ULP
        rng = np.random.default_rng(1234)
        x = np.ones(10).astype(np.float64)
        x += 0.01 * rng.standard_normal(10).astype(np.float64)
        eps = np.finfo(np.float64).eps
        assert_array_max_ulp(x, x + eps, maxulp=200)

    def test_inf(self):
        for dt in [np.float32, np.float64]:
            inf = np.array([np.inf]).astype(dt)
            big = np.array([np.finfo(dt).max])
            assert_array_max_ulp(inf, big, maxulp=200)

    def test_nan(self):
        # NaN must be 'far' (beyond maxulp) from inf, the largest finite
        # value, finfo.tiny, and both signed zeros.
        maxulps = {np.float32: 1e6, np.float64: 1e12}
        for dt, maxulp in maxulps.items():
            nan = np.array([np.nan]).astype(dt)
            others = [
                np.array([np.inf]).astype(dt),
                np.array([np.finfo(dt).max]),
                np.array([np.finfo(dt).tiny]),
                np.array([0.0]).astype(dt),
                np.array([-0.0]).astype(dt),
            ]
            for other in others:
                assert_raises(AssertionError, assert_array_max_ulp,
                              nan, other, maxulp=maxulp)
class TestStringEqual:
    """Tests for ``assert_string_equal``."""

    def test_simple(self):
        for text in ("hello", "hello\nmultiline"):
            assert_string_equal(text, text)

        # Only the differing line is reported: the first argument's line
        # prefixed by "- ", the second's by "+ "; the shared "bar" line is
        # not part of the message.
        with pytest.raises(AssertionError) as exc_info:
            assert_string_equal("foo\nbar", "hello\nbar")
        assert_equal(str(exc_info.value),
                     "Differences in strings:\n- foo\n+ hello")

        assert_raises(AssertionError, assert_string_equal, "foo", "hello")

    def test_regex(self):
        # Strings are compared literally, not interpreted as patterns.
        assert_string_equal("a+*b", "a+*b")
        assert_raises(AssertionError, assert_string_equal, "aaa", "a+b")
def assert_warn_len_equal(mod, n_in_context):
    """Assert that ``mod.__warningregistry__`` holds ``n_in_context`` warnings.

    A module without a ``__warningregistry__`` attribute counts as having
    no warnings, and the registry's 'version' entry is not counted.
    """
    # The attribute is missing when no warning has occurred yet.  That can
    # happen in a parallel test scenario; in a serial scenario an initial
    # warning (and therefore the attribute) is always created first.
    registry = getattr(mod, '__warningregistry__', {})
    # Python adds a 'version' entry to the registry; it is not a warning.
    version_entries = 1 if 'version' in registry else 0
    assert_equal(len(registry) - version_entries, n_in_context)
def test_warn_len_equal_call_scenarios():
    """Probe both code paths of ``assert_warn_len_equal``.

    Depending on serial vs. parallel test execution the helper sees a module
    with or without ``__warningregistry__``; neither path may raise.
    """
    # Parallel scenario: no warning issued yet, so no registry attribute.
    class ModWithoutRegistry:
        pass

    assert_warn_len_equal(mod=ModWithoutRegistry(), n_in_context=0)

    # Serial scenario: the __warningregistry__ attribute is present.
    class ModWithRegistry:
        def __init__(self):
            self.__warningregistry__ = {'warning1': 1,
                                        'warning2': 2}

    assert_warn_len_equal(mod=ModWithRegistry(), n_in_context=2)
def _get_fresh_mod():
    """Return this module after emptying its ``__warningregistry__``.

    The registry only exists once a warning has been raised in the module
    at some point, so a missing attribute is silently accepted.
    """
    this_module = sys.modules[__name__]
    try:
        registry = this_module.__warningregistry__
    except AttributeError:
        pass
    else:
        registry.clear()
    return this_module
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_clear_and_catch_warnings():
    """Check how ``clear_and_catch_warnings`` treats this module's registry,
    with and without the module listed in ``modules``."""
    # Start from a clean module: nothing recorded.
    this_module = _get_fresh_mod()
    assert_equal(getattr(this_module, '__warningregistry__', {}), {})

    def warn_ignored(message, **context_kwargs):
        # Issue ``message`` under an 'ignore' filter inside a fresh
        # clear_and_catch_warnings context (called from this module).
        with clear_and_catch_warnings(**context_kwargs):
            warnings.simplefilter('ignore')
            warnings.warn(message)

    # With the module specified, its registry is empty afterwards.
    warn_ignored('Some warning', modules=[this_module])
    assert_equal(this_module.__warningregistry__, {})

    # Without specified modules, don't clear warnings during context.
    # catch_warnings doesn't make an entry for 'ignore'.
    warn_ignored('Some warning')
    assert_warn_len_equal(this_module, 0)

    # Manually add two warnings to the registry.
    this_module.__warningregistry__ = {'warning1': 1,
                                       'warning2': 2}

    # Specifying the module keeps the old warnings and adds no new one.
    warn_ignored('Another warning', modules=[this_module])
    assert_warn_len_equal(this_module, 2)

    # Without a module spec, the registry is cleared up.
    warn_ignored('Another warning')
    assert_warn_len_equal(this_module, 0)
@pytest.mark.filterwarnings(
    "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
    ".*:DeprecationWarning")
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_suppress_warnings_module():
    """Check module-based filtering in ``suppress_warnings``.

    Warnings are issued both from this module and (via apply_along_axis)
    attributed to ``np.lib._shape_base_impl``; the test checks what gets
    recorded and that this module's ``__warningregistry__`` stays empty.
    """
    # Initial state of module, no warnings
    my_mod = _get_fresh_mod()
    assert_equal(getattr(my_mod, '__warningregistry__', {}), {})

    def warn_other_module():
        # Apply along axis is implemented in python; stacklevel=2 means
        # we end up inside its module, not ours.
        def warn(arr):
            warnings.warn("Some warning 2", stacklevel=2)
            return arr
        np.apply_along_axis(warn, 0, [0])

    # Test module based warning suppression:
    assert_warn_len_equal(my_mod, 0)
    with suppress_warnings() as sup:
        # Record UserWarnings that are not filtered out below.
        sup.record(UserWarning)
        # Filter out the warning attributed to apply_along_axis's module
        # (its file may have a .pyc ending).  Has to be changed if
        # apply_along_axis is moved to another module.
        sup.filter(module=np.lib._shape_base_impl)
        warnings.warn("Some warning")
        warn_other_module()
    # Only the warning issued from this module was recorded; the one
    # attributed to apply_along_axis's module was filtered out.
    assert_equal(len(sup.log), 1)
    assert_equal(sup.log[0].message.args[0], "Some warning")
    assert_warn_len_equal(my_mod, 0)
    sup = suppress_warnings()
    # Filter warnings issued from this module.
    sup.filter(module=my_mod)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # And test repeat works: the same suppress_warnings instance is entered
    # a second time after adding the filter again.
    sup.filter(module=my_mod)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)

    # Without specified modules
    with suppress_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
@pytest.mark.filterwarnings(
    "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
    ".*:DeprecationWarning")
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_suppress_warnings_type():
    """Check ``suppress_warnings`` filters keyed on the warning category.

    Suppressed warnings must leave no entries in this module's
    ``__warningregistry__``.
    """
    module = _get_fresh_mod()
    # The freshly reset module starts without registry entries.
    assert_equal(getattr(module, '__warningregistry__', {}), {})

    # Category-based suppression configured inside a one-off context.
    with suppress_warnings() as one_off:
        one_off.filter(UserWarning)
        warnings.warn('Some warning')
    assert_warn_len_equal(module, 0)

    # The same category filter, registered before entering a reusable
    # suppressor.
    reusable = suppress_warnings()
    reusable.filter(UserWarning)
    with reusable:
        warnings.warn('Some warning')
    assert_warn_len_equal(module, 0)

    # Re-enter the suppressor after also registering a module filter.
    reusable.filter(module=module)
    with reusable:
        warnings.warn('Some warning')
    assert_warn_len_equal(module, 0)

    # No category filter at all: rely on a plain 'ignore' filter instead.
    with suppress_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(module, 0)
@pytest.mark.filterwarnings(
    "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
    ".*:DeprecationWarning")
@pytest.mark.thread_unsafe(
    reason="uses deprecated thread-unsafe warnings control utilities"
)
def test_suppress_warnings_decorate_no_record():
    """A ``suppress_warnings`` instance used as a decorator filters only the
    configured category and lets every other warning through."""
    suppressor = suppress_warnings()
    suppressor.filter(UserWarning)

    @suppressor
    def emit(category):
        warnings.warn('Some warning', category)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        emit(UserWarning)     # filtered out by the decorator
        emit(RuntimeWarning)  # not filtered, so it reaches ``caught``
        assert_equal(len(caught), 1)
@pytest.mark.filterwarnings(
    "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
    ".*:DeprecationWarning")
@pytest.mark.thread_unsafe(
    reason="uses deprecated thread-unsafe warnings control utilities"
)
def test_suppress_warnings_record():
    """Check ``suppress_warnings.record`` logs, on re-entry and when nested."""
    suppressor = suppress_warnings()
    # Registered before the context is entered, so it is active on every
    # entry.
    everything_log = suppressor.record()

    # Enter the same suppressor twice; the second pass checks that nothing
    # from the first pass survived into it.
    for _ in range(2):
        with suppressor:
            specific_log = suppressor.record(message='Some other warning 2')
            suppressor.filter(message='Some warning')
            warnings.warn('Some warning')
            warnings.warn('Some other warning')
            warnings.warn('Some other warning 2')

            assert_equal(len(suppressor.log), 2)
            assert_equal(len(everything_log), 1)
            assert_equal(len(specific_log), 1)
            assert_equal(specific_log[0].message.args[0],
                         'Some other warning 2')

    # Nested suppressors: the inner one records only its message, anything
    # else reaches the outer one.
    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings() as inner:
            inner.record(message='Some warning')
            warnings.warn('Some warning')
            warnings.warn('Some other warning')
            assert_equal(len(inner.log), 1)
        # One forwarded warning plus the DeprecationWarning emitted for the
        # inner suppress_warnings.
        assert_equal(len(outer.log), 2)
@pytest.mark.filterwarnings(
    "ignore:.*NumPy warning suppression and assertion utilities are deprecated"
    ".*:DeprecationWarning")
@pytest.mark.thread_unsafe(
    reason="uses deprecated thread-unsafe warnings control utilities"
)
def test_suppress_warnings_forwarding():
    """Check each forwarding rule of ``suppress_warnings``.

    An outer suppressor records everything forwarded by the inner one; each
    expected count also includes the DeprecationWarning emitted for the
    inner ``suppress_warnings``.
    """
    def warn_from_shape_base():
        # apply_along_axis is implemented in Python, so stacklevel=2
        # attributes the warning to its module rather than to this one.
        def passthrough(arr):
            warnings.warn("Some warning", stacklevel=2)
            return arr
        np.apply_along_axis(passthrough, 0, [0])

    # "always": both repetitions from the same line are forwarded.
    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings("always"):
            for _ in range(2):
                warnings.warn("Some warning")
        assert_equal(len(outer.log), 3)

    # "location": the repeated line counts once, the second line once more.
    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings("location"):
            for _ in range(2):
                warnings.warn("Some warning")
            warnings.warn("Some warning")
        assert_equal(len(outer.log), 3)

    # "module": this module counts once, the apply_along_axis module once.
    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings("module"):
            for _ in range(2):
                warnings.warn("Some warning")
            warnings.warn("Some warning")
            warn_from_shape_base()
        assert_equal(len(outer.log), 3)

    # "once": each distinct message counts once, whatever its origin.
    with suppress_warnings() as outer:
        outer.record()
        with suppress_warnings("once"):
            for _ in range(2):
                warnings.warn("Some warning")
            warnings.warn("Some other warning")
            warn_from_shape_base()
        assert_equal(len(outer.log), 3)
def test_tempdir():
    """``tempdir`` yields an existing directory and removes it on exit.

    Removal is checked both for a normal exit and for an exit caused by an
    exception propagating out of the ``with`` block.
    """
    with tempdir() as tdir:
        assert_(os.path.isdir(tdir))
        # Put a file inside so a non-empty directory has to be removed.
        fpath = os.path.join(tdir, 'tmp')
        with open(fpath, 'w'):
            pass
    assert_(not os.path.isdir(tdir))

    with assert_raises(ValueError):
        with tempdir() as tdir:
            # Confirm the directory really existed; otherwise the check
            # after the block could pass without anything being cleaned up.
            assert_(os.path.isdir(tdir))
            raise ValueError
    assert_(not os.path.isdir(tdir))
def test_temppath():
    """``temppath`` yields a path to a temporary file and removes it on exit.

    Removal is checked both for a normal exit and for an exit caused by an
    exception propagating out of the ``with`` block.
    """
    with temppath() as fpath:
        with open(fpath, 'w'):
            pass
    assert_(not os.path.isfile(fpath))

    with assert_raises(ValueError):
        with temppath() as fpath:
            # The file should already exist here; checking it keeps the
            # assertion after the block from passing vacuously.
            assert_(os.path.isfile(fpath))
            raise ValueError
    assert_(not os.path.isfile(fpath))
class my_cacw(clear_and_catch_warnings):
    """``clear_and_catch_warnings`` whose default modules are this module.

    Used to check that subclasses can provide default modules through the
    ``class_modules`` class attribute.
    """

    class_modules = tuple([sys.modules[__name__]])
@pytest.mark.thread_unsafe(reason="checks global module & deprecated warnings")
def test_clear_and_catch_warnings_inherit():
    """A subclass can supply its default modules via ``class_modules``."""
    module = _get_fresh_mod()
    # my_cacw's class_modules holds this test module, so no ``modules``
    # argument is passed here.
    context = my_cacw()
    with context:
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_equal(module.__warningregistry__, {})
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
@pytest.mark.thread_unsafe(reason="garbage collector is global state")
class TestAssertNoGcCycles:
    """Tests for ``assert_no_gc_cycles`` (context-manager and call forms)."""

    def test_passes(self):
        """Code that builds no reference cycles passes in both forms."""
        def no_cycle():
            # Nested lists, but nothing refers back to itself.
            b = []
            b.append([])
            return b

        with assert_no_gc_cycles():
            no_cycle()

        assert_no_gc_cycles(no_cycle)

    def test_asserts(self):
        """Code that builds a reference cycle fails in both forms."""
        def make_cycle():
            # The list contains itself, forming a cycle.
            a = []
            a.append(a)
            a.append(a)
            return a

        with assert_raises(AssertionError):
            with assert_no_gc_cycles():
                make_cycle()

        with assert_raises(AssertionError):
            assert_no_gc_cycles(make_cycle)

    @pytest.mark.slow
    def test_fails(self):
        """
        Test that in cases where the garbage cannot be collected, we raise an
        error, instead of hanging forever trying to clear it.
        """

        class ReferenceCycleInDel:
            """
            An object that not only contains a reference cycle, but creates new
            cycles whenever it's garbage-collected and its __del__ runs
            """
            make_cycle = True

            def __init__(self):
                self.cycle = self

            def __del__(self):
                # break the current cycle so that `self` can be freed
                self.cycle = None

                if ReferenceCycleInDel.make_cycle:
                    # but create a new one so that the garbage collector (GC)
                    # has more work to do.
                    ReferenceCycleInDel()

        try:
            w = weakref.ref(ReferenceCycleInDel())
            try:
                with assert_raises(RuntimeError):
                    # this will be unable to get a baseline empty garbage
                    assert_no_gc_cycles(lambda: None)
            except AssertionError:
                # The RuntimeError is only expected if the GC actually tried
                # to free our object. If it never did, skip; if it did, the
                # missing RuntimeError is a real failure and must propagate
                # (the re-raise must not sit under the skip branch, where it
                # would be unreachable).
                if w() is not None:
                    pytest.skip("GC does not call __del__ on cyclic objects")
                raise

        finally:
            # make sure that we stop creating reference cycles
            ReferenceCycleInDel.make_cycle = False
|