test_c_api.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import numpy as np
  2. from scipy._lib._array_api import xp_assert_close
  3. from scipy import ndimage
  4. from scipy.ndimage import _ctest
  5. from scipy.ndimage import _cytest
  6. from scipy._lib._ccallback import LowLevelCallable
  7. FILTER1D_FUNCTIONS = [
  8. lambda filter_size: _ctest.filter1d(filter_size),
  9. lambda filter_size: _cytest.filter1d(filter_size, with_signature=False),
  10. lambda filter_size: LowLevelCallable(
  11. _cytest.filter1d(filter_size, with_signature=True)
  12. ),
  13. lambda filter_size: LowLevelCallable.from_cython(
  14. _cytest, "_filter1d",
  15. _cytest.filter1d_capsule(filter_size),
  16. ),
  17. ]
  18. FILTER2D_FUNCTIONS = [
  19. lambda weights: _ctest.filter2d(weights),
  20. lambda weights: _cytest.filter2d(weights, with_signature=False),
  21. lambda weights: LowLevelCallable(_cytest.filter2d(weights, with_signature=True)),
  22. lambda weights: LowLevelCallable.from_cython(_cytest,
  23. "_filter2d",
  24. _cytest.filter2d_capsule(weights),),
  25. ]
  26. TRANSFORM_FUNCTIONS = [
  27. lambda shift: _ctest.transform(shift),
  28. lambda shift: _cytest.transform(shift, with_signature=False),
  29. lambda shift: LowLevelCallable(_cytest.transform(shift, with_signature=True)),
  30. lambda shift: LowLevelCallable.from_cython(_cytest,
  31. "_transform",
  32. _cytest.transform_capsule(shift),),
  33. ]
  34. def test_generic_filter():
  35. def filter2d(footprint_elements, weights):
  36. return (weights*footprint_elements).sum()
  37. def check(j):
  38. func = FILTER2D_FUNCTIONS[j]
  39. im = np.ones((20, 20))
  40. im[:10,:10] = 0
  41. footprint = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
  42. footprint_size = np.count_nonzero(footprint)
  43. weights = np.ones(footprint_size)/footprint_size
  44. res = ndimage.generic_filter(im, func(weights),
  45. footprint=footprint)
  46. std = ndimage.generic_filter(im, filter2d, footprint=footprint,
  47. extra_arguments=(weights,))
  48. xp_assert_close(res, std, err_msg=f"#{j} failed")
  49. for j, func in enumerate(FILTER2D_FUNCTIONS):
  50. check(j)
  51. def test_generic_filter1d():
  52. def filter1d(input_line, output_line, filter_size):
  53. for i in range(output_line.size):
  54. output_line[i] = 0
  55. for j in range(filter_size):
  56. output_line[i] += input_line[i+j]
  57. output_line /= filter_size
  58. def check(j):
  59. func = FILTER1D_FUNCTIONS[j]
  60. im = np.tile(np.hstack((np.zeros(10), np.ones(10))), (10, 1))
  61. filter_size = 3
  62. res = ndimage.generic_filter1d(im, func(filter_size),
  63. filter_size)
  64. std = ndimage.generic_filter1d(im, filter1d, filter_size,
  65. extra_arguments=(filter_size,))
  66. xp_assert_close(res, std, err_msg=f"#{j} failed")
  67. for j, func in enumerate(FILTER1D_FUNCTIONS):
  68. check(j)
  69. def test_geometric_transform():
  70. def transform(output_coordinates, shift):
  71. return output_coordinates[0] - shift, output_coordinates[1] - shift
  72. def check(j):
  73. func = TRANSFORM_FUNCTIONS[j]
  74. im = np.arange(12).reshape(4, 3).astype(np.float64)
  75. shift = 0.5
  76. res = ndimage.geometric_transform(im, func(shift))
  77. std = ndimage.geometric_transform(im, transform, extra_arguments=(shift,))
  78. xp_assert_close(res, std, err_msg=f"#{j} failed")
  79. for j, func in enumerate(TRANSFORM_FUNCTIONS):
  80. check(j)