hyp2f1_data.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. """This script evaluates scipy's implementation of hyp2f1 against mpmath's.
  2. Author: Albert Steppi
  3. This script is long running and generates a large output file. With default
  4. arguments, the generated file is roughly 700MB in size and it takes around
  5. 40 minutes using an Intel(R) Core(TM) i5-8250U CPU with n_jobs set to 8
  6. (full utilization). There are optional arguments which can be used to restrict
  7. (or enlarge) the computations performed. These are described below.
  8. The output of this script can be analyzed to identify suitable test cases and
  9. to find parameter and argument regions where hyp2f1 needs to be improved.
  10. The script has one mandatory positional argument for specifying the path to
  11. the location where the output file is to be placed, and 6 optional arguments
  12. --n_jobs, --box_size, --grid_size, --regions, --parameter_groups, and --no_mp.
  13. --n_jobs sets the number of processes to use (default 1) and --no_mp skips
  14. the mpmath computation. The other optional arguments are explained below.
  15. Produces a tab separated values file with 11 columns. The first four columns
  16. contain the parameters a, b, c and the argument z. The next two contain |z| and
  17. a code for the region of the complex plane that z belongs to. The regions are
  18. 0) z == 1
  19. 1) |z| < 0.9 and real(z) >= 0
  20. 2) |z| <= 1 and real(z) < 0
  21. 3) 0.9 <= |z| <= 1 and |1 - z| < 0.9
  22. 4) 0.9 <= |z| <= 1 and |1 - z| >= 0.9 and real(z) >= 0
  23. 5) 1 < |z| < 1.1 and |1 - z| >= 0.9 and real(z) >= 0
  24. 6) |z| > 1 and not in 5)
  25. The --regions optional argument allows the user to specify a list of regions
  26. to which computation will be restricted.
  27. Parameters a, b, c are taken from a 10 * 10 * 10 grid with values at
  28. -16, -8, -4, -2, -1, 1, 2, 4, 8, 16
  29. with random perturbations applied.
  30. There are 9 parameter groups handling the following cases.
  31. 1) A, B, C, B - A, C - A, C - B, C - A - B all non-integral.
  32. 2) B - A integral
  33. 3) C - A integral
  34. 4) C - B integral
  35. 5) C - A - B integral
  36. 6) A integral
  37. 7) B integral
  38. 8) C integral
  39. 9) Wider range with c - a - b > 0.
  40. The seventh column of the output file is an integer between 1 and 9 specifying
  41. the parameter group as above.
  42. The --parameter_groups optional argument allows the user to specify a list of
  43. parameter groups to which computation will be restricted.
  44. The argument z is taken from a grid in the box
  45. -box_size <= real(z) <= box_size, -box_size <= imag(z) <= box_size.
  46. with grid size specified using the optional command line argument --grid_size,
  47. and box_size specified with the command line argument --box_size.
  48. The default value of grid_size is 20 and the default value of box_size is 2.0,
  49. yielding a 20 * 20 grid in the box with corners -2-2j, -2+2j, 2-2j, 2+2j.
  50. The final four columns have the expected value of hyp2f1 for the given
  51. parameters and argument as calculated with mpmath, the observed value
  52. calculated with scipy's hyp2f1, the relative error, and the absolute error.
  53. As special cases of hyp2f1 are moved from the original Fortran implementation
  54. into Cython, this script can be used to ensure that no regressions occur and
  55. to point out where improvements are needed.
  56. """
  57. import os
  58. import csv
  59. import argparse
  60. import numpy as np
  61. from itertools import product
  62. from multiprocessing import Pool
  63. from scipy.special import hyp2f1
  64. from scipy.special.tests.test_hyp2f1 import mp_hyp2f1
  65. def get_region(z):
  66. """Assign numbers for regions where hyp2f1 must be handled differently."""
  67. if z == 1 + 0j:
  68. return 0
  69. elif abs(z) < 0.9 and z.real >= 0:
  70. return 1
  71. elif abs(z) <= 1 and z.real < 0:
  72. return 2
  73. elif 0.9 <= abs(z) <= 1 and abs(1 - z) < 0.9:
  74. return 3
  75. elif 0.9 <= abs(z) <= 1 and abs(1 - z) >= 0.9:
  76. return 4
  77. elif 1 < abs(z) < 1.1 and abs(1 - z) >= 0.9 and z.real >= 0:
  78. return 5
  79. else:
  80. return 6
  81. def get_result(a, b, c, z, group):
  82. """Get results for given parameter and value combination."""
  83. expected, observed = mp_hyp2f1(a, b, c, z), hyp2f1(a, b, c, z)
  84. if (
  85. np.isnan(observed) and np.isnan(expected) or
  86. expected == observed
  87. ):
  88. relative_error = 0.0
  89. absolute_error = 0.0
  90. elif np.isnan(observed):
  91. # Set error to infinity if result is nan when not expected to be.
  92. # Makes results easier to interpret.
  93. relative_error = float("inf")
  94. absolute_error = float("inf")
  95. else:
  96. absolute_error = abs(expected - observed)
  97. relative_error = absolute_error / abs(expected)
  98. return (
  99. a,
  100. b,
  101. c,
  102. z,
  103. abs(z),
  104. get_region(z),
  105. group,
  106. expected,
  107. observed,
  108. relative_error,
  109. absolute_error,
  110. )
  111. def get_result_no_mp(a, b, c, z, group):
  112. """Get results for given parameter and value combination."""
  113. expected, observed = complex('nan'), hyp2f1(a, b, c, z)
  114. relative_error, absolute_error = float('nan'), float('nan')
  115. return (
  116. a,
  117. b,
  118. c,
  119. z,
  120. abs(z),
  121. get_region(z),
  122. group,
  123. expected,
  124. observed,
  125. relative_error,
  126. absolute_error,
  127. )
  128. def get_results(params, Z, n_jobs=1, compute_mp=True):
  129. """Batch compute results for multiple parameter and argument values.
  130. Parameters
  131. ----------
  132. params : iterable
  133. iterable of tuples of floats (a, b, c) specifying parameter values
  134. a, b, c for hyp2f1
  135. Z : iterable of complex
  136. Arguments at which to evaluate hyp2f1
  137. n_jobs : Optional[int]
  138. Number of jobs for parallel execution.
  139. Returns
  140. -------
  141. list
  142. List of tuples of results values. See return value in source code
  143. of `get_result`.
  144. """
  145. input_ = (
  146. (a, b, c, z, group) for (a, b, c, group), z in product(params, Z)
  147. )
  148. with Pool(n_jobs) as pool:
  149. rows = pool.starmap(
  150. get_result if compute_mp else get_result_no_mp,
  151. input_
  152. )
  153. return rows
  154. def _make_hyp2f1_test_case(a, b, c, z, rtol):
  155. """Generate string for single test case as used in test_hyp2f1.py."""
  156. expected = mp_hyp2f1(a, b, c, z)
  157. return (
  158. " pytest.param(\n"
  159. " Hyp2f1TestCase(\n"
  160. f" a={a},\n"
  161. f" b={b},\n"
  162. f" c={c},\n"
  163. f" z={z},\n"
  164. f" expected={expected},\n"
  165. f" rtol={rtol},\n"
  166. " ),\n"
  167. " ),"
  168. )
  169. def make_hyp2f1_test_cases(rows):
  170. """Generate string for a list of test cases for test_hyp2f1.py.
  171. Parameters
  172. ----------
  173. rows : list
  174. List of lists of the form [a, b, c, z, rtol] where a, b, c, z are
  175. parameters and the argument for hyp2f1 and rtol is an expected
  176. relative error for the associated test case.
  177. Returns
  178. -------
  179. str
  180. String for a list of test cases. The output string can be printed
  181. or saved to a file and then copied into an argument for
  182. `pytest.mark.parameterize` within `scipy.special.tests.test_hyp2f1.py`.
  183. """
  184. result = "[\n"
  185. result += '\n'.join(
  186. _make_hyp2f1_test_case(a, b, c, z, rtol)
  187. for a, b, c, z, rtol in rows
  188. )
  189. result += "\n]"
  190. return result
  191. def main(
  192. outpath,
  193. n_jobs=1,
  194. box_size=2.0,
  195. grid_size=20,
  196. regions=None,
  197. parameter_groups=None,
  198. compute_mp=True,
  199. ):
  200. outpath = os.path.realpath(os.path.expanduser(outpath))
  201. random_state = np.random.RandomState(1234)
  202. # Parameters a, b, c selected near these values.
  203. root_params = np.array(
  204. [-16, -8, -4, -2, -1, 1, 2, 4, 8, 16]
  205. )
  206. # Perturbations to apply to root values.
  207. perturbations = 0.1 * random_state.random_sample(
  208. size=(3, len(root_params))
  209. )
  210. params = []
  211. # Parameter group 1
  212. # -----------------
  213. # No integer differences. This has been confirmed for the above seed.
  214. A = root_params + perturbations[0, :]
  215. B = root_params + perturbations[1, :]
  216. C = root_params + perturbations[2, :]
  217. params.extend(
  218. sorted(
  219. ((a, b, c, 1) for a, b, c in product(A, B, C)),
  220. key=lambda x: max(abs(x[0]), abs(x[1])),
  221. )
  222. )
  223. # Parameter group 2
  224. # -----------------
  225. # B - A an integer
  226. A = root_params + 0.5
  227. B = root_params + 0.5
  228. C = root_params + perturbations[1, :]
  229. params.extend(
  230. sorted(
  231. ((a, b, c, 2) for a, b, c in product(A, B, C)),
  232. key=lambda x: max(abs(x[0]), abs(x[1])),
  233. )
  234. )
  235. # Parameter group 3
  236. # -----------------
  237. # C - A an integer
  238. A = root_params + 0.5
  239. B = root_params + perturbations[1, :]
  240. C = root_params + 0.5
  241. params.extend(
  242. sorted(
  243. ((a, b, c, 3) for a, b, c in product(A, B, C)),
  244. key=lambda x: max(abs(x[0]), abs(x[1])),
  245. )
  246. )
  247. # Parameter group 4
  248. # -----------------
  249. # C - B an integer
  250. A = root_params + perturbations[0, :]
  251. B = root_params + 0.5
  252. C = root_params + 0.5
  253. params.extend(
  254. sorted(
  255. ((a, b, c, 4) for a, b, c in product(A, B, C)),
  256. key=lambda x: max(abs(x[0]), abs(x[1])),
  257. )
  258. )
  259. # Parameter group 5
  260. # -----------------
  261. # C - A - B an integer
  262. A = root_params + 0.25
  263. B = root_params + 0.25
  264. C = root_params + 0.5
  265. params.extend(
  266. sorted(
  267. ((a, b, c, 5) for a, b, c in product(A, B, C)),
  268. key=lambda x: max(abs(x[0]), abs(x[1])),
  269. )
  270. )
  271. # Parameter group 6
  272. # -----------------
  273. # A an integer
  274. A = root_params
  275. B = root_params + perturbations[0, :]
  276. C = root_params + perturbations[1, :]
  277. params.extend(
  278. sorted(
  279. ((a, b, c, 6) for a, b, c in product(A, B, C)),
  280. key=lambda x: max(abs(x[0]), abs(x[1])),
  281. )
  282. )
  283. # Parameter group 7
  284. # -----------------
  285. # B an integer
  286. A = root_params + perturbations[0, :]
  287. B = root_params
  288. C = root_params + perturbations[1, :]
  289. params.extend(
  290. sorted(
  291. ((a, b, c, 7) for a, b, c in product(A, B, C)),
  292. key=lambda x: max(abs(x[0]), abs(x[1])),
  293. )
  294. )
  295. # Parameter group 8
  296. # -----------------
  297. # C an integer
  298. A = root_params + perturbations[0, :]
  299. B = root_params + perturbations[1, :]
  300. C = root_params
  301. params.extend(
  302. sorted(
  303. ((a, b, c, 8) for a, b, c in product(A, B, C)),
  304. key=lambda x: max(abs(x[0]), abs(x[1])),
  305. )
  306. )
  307. # Parameter group 9
  308. # -----------------
  309. # Wide range of magnitudes, c - a - b > 0.
  310. phi = (1 + np.sqrt(5))/2
  311. P = phi**np.arange(16)
  312. P = np.hstack([-P, P])
  313. group_9_params = sorted(
  314. (
  315. (a, b, c, 9) for a, b, c in product(P, P, P) if c - a - b > 0
  316. ),
  317. key=lambda x: max(abs(x[0]), abs(x[1])),
  318. )
  319. if parameter_groups is not None:
  320. # Group 9 params only used if specified in arguments.
  321. params.extend(group_9_params)
  322. params = [
  323. (a, b, c, group) for a, b, c, group in params
  324. if group in parameter_groups
  325. ]
  326. # grid_size * grid_size grid in box with corners
  327. # -2 - 2j, -2 + 2j, 2 - 2j, 2 + 2j
  328. X, Y = np.meshgrid(
  329. np.linspace(-box_size, box_size, grid_size),
  330. np.linspace(-box_size, box_size, grid_size)
  331. )
  332. Z = X + Y * 1j
  333. Z = Z.flatten().tolist()
  334. # Add z = 1 + 0j (region 0).
  335. Z.append(1 + 0j)
  336. if regions is not None:
  337. Z = [z for z in Z if get_region(z) in regions]
  338. # Evaluate scipy and mpmath's hyp2f1 for all parameter combinations
  339. # above against all arguments in the grid Z
  340. rows = get_results(params, Z, n_jobs=n_jobs, compute_mp=compute_mp)
  341. with open(outpath, "w", newline="") as f:
  342. writer = csv.writer(f, delimiter="\t")
  343. writer.writerow(
  344. [
  345. "a",
  346. "b",
  347. "c",
  348. "z",
  349. "|z|",
  350. "region",
  351. "parameter_group",
  352. "expected", # mpmath's hyp2f1
  353. "observed", # scipy's hyp2f1
  354. "relative_error",
  355. "absolute_error",
  356. ]
  357. )
  358. for row in rows:
  359. writer.writerow(row)
  360. if __name__ == "__main__":
  361. parser = argparse.ArgumentParser(
  362. description="Test scipy's hyp2f1 against mpmath's on a grid in the"
  363. " complex plane over a grid of parameter values. Saves output to file"
  364. " specified in positional argument \"outpath\"."
  365. " Caution: With default arguments, the generated output file is"
  366. " roughly 700MB in size. Script may take several hours to finish if"
  367. " \"--n_jobs\" is set to 1."
  368. )
  369. parser.add_argument(
  370. "outpath", type=str, help="Path to output tsv file."
  371. )
  372. parser.add_argument(
  373. "--n_jobs",
  374. type=int,
  375. default=1,
  376. help="Number of jobs for multiprocessing.",
  377. )
  378. parser.add_argument(
  379. "--box_size",
  380. type=float,
  381. default=2.0,
  382. help="hyp2f1 is evaluated in box of side_length 2*box_size centered"
  383. " at the origin."
  384. )
  385. parser.add_argument(
  386. "--grid_size",
  387. type=int,
  388. default=20,
  389. help="hyp2f1 is evaluated on grid_size * grid_size grid in box of side"
  390. " length 2*box_size centered at the origin."
  391. )
  392. parser.add_argument(
  393. "--parameter_groups",
  394. type=int,
  395. nargs='+',
  396. default=None,
  397. help="Restrict to supplied parameter groups. See the Docstring for"
  398. " this module for more info on parameter groups. Calculate for all"
  399. " parameter groups by default."
  400. )
  401. parser.add_argument(
  402. "--regions",
  403. type=int,
  404. nargs='+',
  405. default=None,
  406. help="Restrict to argument z only within the supplied regions. See"
  407. " the Docstring for this module for more info on regions. Calculate"
  408. " for all regions by default."
  409. )
  410. parser.add_argument(
  411. "--no_mp",
  412. action='store_true',
  413. help="If this flag is set, do not compute results with mpmath. Saves"
  414. " time if results have already been computed elsewhere. Fills in"
  415. " \"expected\" column with None values."
  416. )
  417. args = parser.parse_args()
  418. compute_mp = not args.no_mp
  419. print(args.parameter_groups)
  420. main(
  421. args.outpath,
  422. n_jobs=args.n_jobs,
  423. box_size=args.box_size,
  424. grid_size=args.grid_size,
  425. parameter_groups=args.parameter_groups,
  426. regions=args.regions,
  427. compute_mp=compute_mp,
  428. )