_testutils.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import numpy as np
  2. class _FakeMatrix:
  3. def __init__(self, data):
  4. self._data = data
  5. self.__array_interface__ = data.__array_interface__
  6. class _FakeMatrix2:
  7. def __init__(self, data):
  8. self._data = data
  9. def __array__(self, dtype=None, copy=None):
  10. if copy:
  11. return self._data.copy()
  12. return self._data
  13. def _get_array(shape, dtype):
  14. """
  15. Get a test array of given shape and data type.
  16. Returned NxN matrices are posdef, and 2xN are banded-posdef.
  17. """
  18. if len(shape) == 2 and shape[0] == 2:
  19. # yield a banded positive definite one
  20. x = np.zeros(shape, dtype=dtype)
  21. x[0, 1:] = -1
  22. x[1] = 2
  23. return x
  24. elif len(shape) == 2 and shape[0] == shape[1]:
  25. # always yield a positive definite matrix
  26. x = np.zeros(shape, dtype=dtype)
  27. j = np.arange(shape[0])
  28. x[j, j] = 2
  29. x[j[:-1], j[:-1]+1] = -1
  30. x[j[:-1]+1, j[:-1]] = -1
  31. return x
  32. else:
  33. np.random.seed(1234)
  34. return np.random.randn(*shape).astype(dtype)
  35. def _id(x):
  36. return x
  37. def assert_no_overwrite(call, shapes, dtypes=None):
  38. """
  39. Test that a call does not overwrite its input arguments
  40. """
  41. if dtypes is None:
  42. dtypes = [np.float32, np.float64, np.complex64, np.complex128]
  43. for dtype in dtypes:
  44. for order in ["C", "F"]:
  45. for faker in [_id, _FakeMatrix, _FakeMatrix2]:
  46. orig_inputs = [_get_array(s, dtype) for s in shapes]
  47. inputs = [faker(x.copy(order)) for x in orig_inputs]
  48. call(*inputs)
  49. msg = f"call modified inputs [{dtype!r}, {faker!r}]"
  50. for a, b in zip(inputs, orig_inputs):
  51. np.testing.assert_equal(a, b, err_msg=msg)