testutils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. """Miscellaneous functions for testing masked arrays and subclasses
  2. :author: Pierre Gerard-Marchant
  3. :contact: pierregm_at_uga_dot_edu
  4. :version: $Id: testutils.py 3529 2007-11-13 08:01:14Z jarrod.millman $
  5. """
  6. import operator
  7. import numpy as np
  8. from numpy import ndarray
  9. import numpy._core.umath as umath
  10. import numpy.testing
  11. from numpy.testing import (
  12. assert_, assert_allclose, assert_array_almost_equal_nulp,
  13. assert_raises, build_err_msg
  14. )
  15. from .core import mask_or, getmask, masked_array, nomask, masked, filled
# Names of the masked-array test helpers that this module exports.
__all__masked = [
    'almost', 'approx', 'assert_almost_equal', 'assert_array_almost_equal',
    'assert_array_approx_equal', 'assert_array_compare',
    'assert_array_equal', 'assert_array_less', 'assert_close',
    'assert_equal', 'assert_equal_records', 'assert_mask_equal',
    'assert_not_equal', 'fail_if_array_equal',
]
# Include some normal test functions to avoid breaking other projects that
# have mistakenly imported them from this file. SciPy is one. That is
# unfortunate, as some of these functions are not intended to work with
# masked arrays. But there was no way to tell before.
  27. from unittest import TestCase
# Names taken from ``unittest`` / ``numpy.testing`` that are re-exported
# unchanged alongside the masked-array helpers.
__some__from_testing = [
    'TestCase', 'assert_', 'assert_allclose', 'assert_array_almost_equal_nulp',
    'assert_raises'
]
# The public API is the union of both lists.
__all__ = __all__masked + __some__from_testing
  33. def approx(a, b, fill_value=True, rtol=1e-5, atol=1e-8):
  34. """
  35. Returns true if all components of a and b are equal to given tolerances.
  36. If fill_value is True, masked values considered equal. Otherwise,
  37. masked values are considered unequal. The relative error rtol should
  38. be positive and << 1.0 The absolute error atol comes into play for
  39. those elements of b that are very small or zero; it says how small a
  40. must be also.
  41. """
  42. m = mask_or(getmask(a), getmask(b))
  43. d1 = filled(a)
  44. d2 = filled(b)
  45. if d1.dtype.char == "O" or d2.dtype.char == "O":
  46. return np.equal(d1, d2).ravel()
  47. x = filled(
  48. masked_array(d1, copy=False, mask=m), fill_value
  49. ).astype(np.float64)
  50. y = filled(masked_array(d2, copy=False, mask=m), 1).astype(np.float64)
  51. d = np.less_equal(umath.absolute(x - y), atol + rtol * umath.absolute(y))
  52. return d.ravel()
  53. def almost(a, b, decimal=6, fill_value=True):
  54. """
  55. Returns True if a and b are equal up to decimal places.
  56. If fill_value is True, masked values considered equal. Otherwise,
  57. masked values are considered unequal.
  58. """
  59. m = mask_or(getmask(a), getmask(b))
  60. d1 = filled(a)
  61. d2 = filled(b)
  62. if d1.dtype.char == "O" or d2.dtype.char == "O":
  63. return np.equal(d1, d2).ravel()
  64. x = filled(
  65. masked_array(d1, copy=False, mask=m), fill_value
  66. ).astype(np.float64)
  67. y = filled(masked_array(d2, copy=False, mask=m), 1).astype(np.float64)
  68. d = np.around(np.abs(x - y), decimal) <= 10.0 ** (-decimal)
  69. return d.ravel()
  70. def _assert_equal_on_sequences(actual, desired, err_msg=''):
  71. """
  72. Asserts the equality of two non-array sequences.
  73. """
  74. assert_equal(len(actual), len(desired), err_msg)
  75. for k in range(len(desired)):
  76. assert_equal(actual[k], desired[k], f'item={k!r}\n{err_msg}')
  77. return
  78. def assert_equal_records(a, b):
  79. """
  80. Asserts that two records are equal.
  81. Pretty crude for now.
  82. """
  83. assert_equal(a.dtype, b.dtype)
  84. for f in a.dtype.names:
  85. (af, bf) = (operator.getitem(a, f), operator.getitem(b, f))
  86. if not (af is masked) and not (bf is masked):
  87. assert_equal(operator.getitem(a, f), operator.getitem(b, f))
  88. return
  89. def assert_equal(actual, desired, err_msg=''):
  90. """
  91. Asserts that two items are equal.
  92. """
  93. # Case #1: dictionary .....
  94. if isinstance(desired, dict):
  95. if not isinstance(actual, dict):
  96. raise AssertionError(repr(type(actual)))
  97. assert_equal(len(actual), len(desired), err_msg)
  98. for k, i in desired.items():
  99. if k not in actual:
  100. raise AssertionError(f"{k} not in {actual}")
  101. assert_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}')
  102. return
  103. # Case #2: lists .....
  104. if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
  105. return _assert_equal_on_sequences(actual, desired, err_msg='')
  106. if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
  107. msg = build_err_msg([actual, desired], err_msg,)
  108. if not desired == actual:
  109. raise AssertionError(msg)
  110. return
  111. # Case #4. arrays or equivalent
  112. if ((actual is masked) and not (desired is masked)) or \
  113. ((desired is masked) and not (actual is masked)):
  114. msg = build_err_msg([actual, desired],
  115. err_msg, header='', names=('x', 'y'))
  116. raise ValueError(msg)
  117. actual = np.asanyarray(actual)
  118. desired = np.asanyarray(desired)
  119. (actual_dtype, desired_dtype) = (actual.dtype, desired.dtype)
  120. if actual_dtype.char == "S" and desired_dtype.char == "S":
  121. return _assert_equal_on_sequences(actual.tolist(),
  122. desired.tolist(),
  123. err_msg='')
  124. return assert_array_equal(actual, desired, err_msg)
  125. def fail_if_equal(actual, desired, err_msg='',):
  126. """
  127. Raises an assertion error if two items are equal.
  128. """
  129. if isinstance(desired, dict):
  130. if not isinstance(actual, dict):
  131. raise AssertionError(repr(type(actual)))
  132. fail_if_equal(len(actual), len(desired), err_msg)
  133. for k, i in desired.items():
  134. if k not in actual:
  135. raise AssertionError(repr(k))
  136. fail_if_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}')
  137. return
  138. if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
  139. fail_if_equal(len(actual), len(desired), err_msg)
  140. for k in range(len(desired)):
  141. fail_if_equal(actual[k], desired[k], f'item={k!r}\n{err_msg}')
  142. return
  143. if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
  144. return fail_if_array_equal(actual, desired, err_msg)
  145. msg = build_err_msg([actual, desired], err_msg)
  146. if not desired != actual:
  147. raise AssertionError(msg)
  148. assert_not_equal = fail_if_equal
  149. def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
  150. """
  151. Asserts that two items are almost equal.
  152. The test is equivalent to abs(desired-actual) < 0.5 * 10**(-decimal).
  153. """
  154. if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
  155. return assert_array_almost_equal(actual, desired, decimal=decimal,
  156. err_msg=err_msg, verbose=verbose)
  157. msg = build_err_msg([actual, desired],
  158. err_msg=err_msg, verbose=verbose)
  159. if not round(abs(desired - actual), decimal) == 0:
  160. raise AssertionError(msg)
  161. assert_close = assert_almost_equal
  162. def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
  163. fill_value=True):
  164. """
  165. Asserts that comparison between two masked arrays is satisfied.
  166. The comparison is elementwise.
  167. """
  168. # Allocate a common mask and refill
  169. m = mask_or(getmask(x), getmask(y))
  170. x = masked_array(x, copy=False, mask=m, keep_mask=False, subok=False)
  171. y = masked_array(y, copy=False, mask=m, keep_mask=False, subok=False)
  172. if ((x is masked) and not (y is masked)) or \
  173. ((y is masked) and not (x is masked)):
  174. msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose,
  175. header=header, names=('x', 'y'))
  176. raise ValueError(msg)
  177. # OK, now run the basic tests on filled versions
  178. return np.testing.assert_array_compare(comparison,
  179. x.filled(fill_value),
  180. y.filled(fill_value),
  181. err_msg=err_msg,
  182. verbose=verbose, header=header)
  183. def assert_array_equal(x, y, err_msg='', verbose=True):
  184. """
  185. Checks the elementwise equality of two masked arrays.
  186. """
  187. assert_array_compare(operator.__eq__, x, y,
  188. err_msg=err_msg, verbose=verbose,
  189. header='Arrays are not equal')
  190. def fail_if_array_equal(x, y, err_msg='', verbose=True):
  191. """
  192. Raises an assertion error if two masked arrays are not equal elementwise.
  193. """
  194. def compare(x, y):
  195. return (not np.all(approx(x, y)))
  196. assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
  197. header='Arrays are not equal')
  198. def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True):
  199. """
  200. Checks the equality of two masked arrays, up to given number odecimals.
  201. The equality is checked elementwise.
  202. """
  203. def compare(x, y):
  204. "Returns the result of the loose comparison between x and y)."
  205. return approx(x, y, rtol=10. ** -decimal)
  206. assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
  207. header='Arrays are not almost equal')
  208. def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
  209. """
  210. Checks the equality of two masked arrays, up to given number odecimals.
  211. The equality is checked elementwise.
  212. """
  213. def compare(x, y):
  214. "Returns the result of the loose comparison between x and y)."
  215. return almost(x, y, decimal)
  216. assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
  217. header='Arrays are not almost equal')
  218. def assert_array_less(x, y, err_msg='', verbose=True):
  219. """
  220. Checks that x is smaller than y elementwise.
  221. """
  222. assert_array_compare(operator.__lt__, x, y,
  223. err_msg=err_msg, verbose=verbose,
  224. header='Arrays are not less-ordered')
  225. def assert_mask_equal(m1, m2, err_msg=''):
  226. """
  227. Asserts the equality of two masks.
  228. """
  229. if m1 is nomask:
  230. assert_(m2 is nomask)
  231. if m2 is nomask:
  232. assert_(m1 is nomask)
  233. assert_array_equal(m1, m2, err_msg=err_msg)