# test_ccallback.py -- tests for scipy._lib._ccallback.LowLevelCallable
import ctypes
import itertools
import threading
import time

import pytest
from numpy.testing import assert_equal, assert_
from pytest import raises as assert_raises

from scipy._lib import _ccallback_c as _test_ccallback_cython
from scipy._lib import _test_ccallback
from scipy._lib._ccallback import LowLevelCallable
  10. ERROR_VALUE = 2.0
  11. def callback_python(a, user_data=None):
  12. if a == ERROR_VALUE:
  13. raise ValueError("bad value")
  14. if user_data is None:
  15. return a + 1
  16. else:
  17. return a + user_data
  18. def _get_cffi_func(base, signature):
  19. cffi = pytest.importorskip("cffi")
  20. # Get function address
  21. voidp = ctypes.cast(base, ctypes.c_void_p)
  22. address = voidp.value
  23. # Create corresponding cffi handle
  24. ffi = cffi.FFI()
  25. func = ffi.cast(signature, address)
  26. return func
  27. def _get_ctypes_data():
  28. value = ctypes.c_double(2.0)
  29. return ctypes.cast(ctypes.pointer(value), ctypes.c_voidp)
  30. def _get_cffi_data():
  31. cffi = pytest.importorskip("cffi")
  32. ffi = cffi.FFI()
  33. return ffi.new('double *', 2.0)
  34. CALLERS = {
  35. 'simple': _test_ccallback.test_call_simple,
  36. 'nodata': _test_ccallback.test_call_nodata,
  37. 'nonlocal': _test_ccallback.test_call_nonlocal,
  38. 'cython': _test_ccallback_cython.test_call_cython,
  39. }
  40. # These functions have signatures known to the callers
  41. FUNCS = {
  42. 'python': lambda: callback_python,
  43. 'capsule': lambda: _test_ccallback.test_get_plus1_capsule(),
  44. 'cython': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
  45. "plus1_cython"),
  46. 'ctypes': lambda: _test_ccallback_cython.plus1_ctypes,
  47. 'cffi': lambda: _get_cffi_func(_test_ccallback_cython.plus1_ctypes,
  48. 'double (*)(double, int *, void *)'),
  49. 'capsule_b': lambda: _test_ccallback.test_get_plus1b_capsule(),
  50. 'cython_b': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
  51. "plus1b_cython"),
  52. 'ctypes_b': lambda: _test_ccallback_cython.plus1b_ctypes,
  53. 'cffi_b': lambda: _get_cffi_func(_test_ccallback_cython.plus1b_ctypes,
  54. 'double (*)(double, double, int *, void *)'),
  55. }
  56. # These functions have signatures the callers don't know
  57. BAD_FUNCS = {
  58. 'capsule_bc': lambda: _test_ccallback.test_get_plus1bc_capsule(),
  59. 'cython_bc': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
  60. "plus1bc_cython"),
  61. 'ctypes_bc': lambda: _test_ccallback_cython.plus1bc_ctypes,
  62. 'cffi_bc': lambda: _get_cffi_func(
  63. _test_ccallback_cython.plus1bc_ctypes,
  64. 'double (*)(double, double, double, int *, void *)'
  65. ),
  66. }
  67. USER_DATAS = {
  68. 'ctypes': _get_ctypes_data,
  69. 'cffi': _get_cffi_data,
  70. 'capsule': _test_ccallback.test_get_data_capsule,
  71. }
  72. def test_callbacks():
  73. def check(caller, func, user_data):
  74. caller = CALLERS[caller]
  75. func = FUNCS[func]()
  76. user_data = USER_DATAS[user_data]()
  77. if func is callback_python:
  78. def func2(x):
  79. return func(x, 2.0)
  80. else:
  81. func2 = LowLevelCallable(func, user_data)
  82. func = LowLevelCallable(func)
  83. # Test basic call
  84. assert_equal(caller(func, 1.0), 2.0)
  85. # Test 'bad' value resulting to an error
  86. assert_raises(ValueError, caller, func, ERROR_VALUE)
  87. # Test passing in user_data
  88. assert_equal(caller(func2, 1.0), 3.0)
  89. for caller in sorted(CALLERS.keys()):
  90. for func in sorted(FUNCS.keys()):
  91. for user_data in sorted(USER_DATAS.keys()):
  92. check(caller, func, user_data)
  93. def test_bad_callbacks():
  94. def check(caller, func, user_data):
  95. caller = CALLERS[caller]
  96. user_data = USER_DATAS[user_data]()
  97. func = BAD_FUNCS[func]()
  98. if func is callback_python:
  99. def func2(x):
  100. return func(x, 2.0)
  101. else:
  102. func2 = LowLevelCallable(func, user_data)
  103. func = LowLevelCallable(func)
  104. # Test that basic call fails
  105. assert_raises(ValueError, caller, LowLevelCallable(func), 1.0)
  106. # Test that passing in user_data also fails
  107. assert_raises(ValueError, caller, func2, 1.0)
  108. # Test error message
  109. llfunc = LowLevelCallable(func)
  110. try:
  111. caller(llfunc, 1.0)
  112. except ValueError as err:
  113. msg = str(err)
  114. assert_(llfunc.signature in msg, msg)
  115. assert_('double (double, double, int *, void *)' in msg, msg)
  116. for caller in sorted(CALLERS.keys()):
  117. for func in sorted(BAD_FUNCS.keys()):
  118. for user_data in sorted(USER_DATAS.keys()):
  119. check(caller, func, user_data)
  120. def test_signature_override():
  121. caller = _test_ccallback.test_call_simple
  122. func = _test_ccallback.test_get_plus1_capsule()
  123. llcallable = LowLevelCallable(func, signature="bad signature")
  124. assert_equal(llcallable.signature, "bad signature")
  125. assert_raises(ValueError, caller, llcallable, 3)
  126. llcallable = LowLevelCallable(func, signature="double (double, int *, void *)")
  127. assert_equal(llcallable.signature, "double (double, int *, void *)")
  128. assert_equal(caller(llcallable, 3), 4)
  129. def test_threadsafety():
  130. def callback(a, caller):
  131. if a <= 0:
  132. return 1
  133. else:
  134. res = caller(lambda x: callback(x, caller), a - 1)
  135. return 2*res
  136. def check(caller):
  137. caller = CALLERS[caller]
  138. results = []
  139. count = 10
  140. def run():
  141. time.sleep(0.01)
  142. r = caller(lambda x: callback(x, caller), count)
  143. results.append(r)
  144. threads = [threading.Thread(target=run) for j in range(20)]
  145. for thread in threads:
  146. thread.start()
  147. for thread in threads:
  148. thread.join()
  149. assert_equal(results, [2.0**count]*len(threads))
  150. for caller in CALLERS.keys():
  151. check(caller)