_mgc.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552
  1. import warnings
  2. import numpy as np
  3. from scipy._lib._array_api import xp_capabilities
  4. from scipy._lib._util import check_random_state, MapWrapper, rng_integers, _contains_nan
  5. from scipy._lib._bunch import _make_tuple_bunch
  6. from scipy.spatial.distance import cdist
  7. from scipy.ndimage import _measurements
  8. from ._stats import _local_correlations # type: ignore[import-not-found]
  9. from . import distributions
  10. __all__ = ['multiscale_graphcorr']
  11. # FROM MGCPY: https://github.com/neurodata/mgcpy
  12. class _ParallelP:
  13. """Helper function to calculate parallel p-value."""
  14. def __init__(self, x, y, random_states):
  15. self.x = x
  16. self.y = y
  17. self.random_states = random_states
  18. def __call__(self, index):
  19. order = self.random_states[index].permutation(self.y.shape[0])
  20. permy = self.y[order][:, order]
  21. # calculate permuted stats, store in null distribution
  22. perm_stat = _mgc_stat(self.x, permy)[0]
  23. return perm_stat
  24. def _perm_test(x, y, stat, reps=1000, workers=-1, random_state=None):
  25. r"""Helper function that calculates the p-value. See below for uses.
  26. Parameters
  27. ----------
  28. x, y : ndarray
  29. `x` and `y` have shapes ``(n, p)`` and ``(n, q)``.
  30. stat : float
  31. The sample test statistic.
  32. reps : int, optional
  33. The number of replications used to estimate the null when using the
  34. permutation test. The default is 1000 replications.
  35. workers : int or map-like callable, optional
  36. If `workers` is an int the population is subdivided into `workers`
  37. sections and evaluated in parallel (uses
  38. `multiprocessing.Pool <multiprocessing>`). Supply `-1` to use all cores
  39. available to the Process. Alternatively supply a map-like callable,
  40. such as `multiprocessing.Pool.map` for evaluating the population in
  41. parallel. This evaluation is carried out as `workers(func, iterable)`.
  42. Requires that `func` be pickleable.
  43. random_state : {None, int, `numpy.random.Generator`,
  44. `numpy.random.RandomState`}, optional
  45. If `seed` is None (or `np.random`), the `numpy.random.RandomState`
  46. singleton is used.
  47. If `seed` is an int, a new ``RandomState`` instance is used,
  48. seeded with `seed`.
  49. If `seed` is already a ``Generator`` or ``RandomState`` instance then
  50. that instance is used.
  51. Returns
  52. -------
  53. pvalue : float
  54. The sample test p-value.
  55. null_dist : list
  56. The approximated null distribution.
  57. """
  58. # generate seeds for each rep (change to new parallel random number
  59. # capabilities in numpy >= 1.17+)
  60. random_state = check_random_state(random_state)
  61. random_states = [np.random.RandomState(rng_integers(random_state, 1 << 32,
  62. size=4, dtype=np.uint32)) for _ in range(reps)]
  63. # parallelizes with specified workers over number of reps and set seeds
  64. parallelp = _ParallelP(x=x, y=y, random_states=random_states)
  65. with MapWrapper(workers) as mapwrapper:
  66. null_dist = np.array(list(mapwrapper(parallelp, range(reps))))
  67. # calculate p-value and significant permutation map through list
  68. pvalue = (1 + (null_dist >= stat).sum()) / (1 + reps)
  69. return pvalue, null_dist
  70. def _euclidean_dist(x):
  71. return cdist(x, x)
  72. MGCResult = _make_tuple_bunch('MGCResult',
  73. ['statistic', 'pvalue', 'mgc_dict'], [])
  74. @xp_capabilities(np_only=True)
  75. def multiscale_graphcorr(x, y, compute_distance=_euclidean_dist, reps=1000,
  76. workers=1, is_twosamp=False, random_state=None):
  77. r"""Computes the Multiscale Graph Correlation (MGC) test statistic.
  78. Specifically, for each point, MGC finds the :math:`k`-nearest neighbors for
  79. one property (e.g. cloud density), and the :math:`l`-nearest neighbors for
  80. the other property (e.g. grass wetness) [1]_. This pair :math:`(k, l)` is
  81. called the "scale". A priori, however, it is not know which scales will be
  82. most informative. So, MGC computes all distance pairs, and then efficiently
  83. computes the distance correlations for all scales. The local correlations
  84. illustrate which scales are relatively informative about the relationship.
  85. The key, therefore, to successfully discover and decipher relationships
  86. between disparate data modalities is to adaptively determine which scales
  87. are the most informative, and the geometric implication for the most
  88. informative scales. Doing so not only provides an estimate of whether the
  89. modalities are related, but also provides insight into how the
  90. determination was made. This is especially important in high-dimensional
  91. data, where simple visualizations do not reveal relationships to the
  92. unaided human eye. Characterizations of this implementation in particular
  93. have been derived from and benchmarked within in [2]_.
  94. Parameters
  95. ----------
  96. x, y : ndarray
  97. If ``x`` and ``y`` have shapes ``(n, p)`` and ``(n, q)`` where `n` is
  98. the number of samples and `p` and `q` are the number of dimensions,
  99. then the MGC independence test will be run. Alternatively, ``x`` and
  100. ``y`` can have shapes ``(n, n)`` if they are distance or similarity
  101. matrices, and ``compute_distance`` must be sent to ``None``. If ``x``
  102. and ``y`` have shapes ``(n, p)`` and ``(m, p)``, an unpaired
  103. two-sample MGC test will be run.
  104. compute_distance : callable, optional
  105. A function that computes the distance or similarity among the samples
  106. within each data matrix. Set to ``None`` if ``x`` and ``y`` are
  107. already distance matrices. The default uses the euclidean norm metric.
  108. If you are calling a custom function, either create the distance
  109. matrix before-hand or create a function of the form
  110. ``compute_distance(x)`` where `x` is the data matrix for which
  111. pairwise distances are calculated.
  112. reps : int, optional
  113. The number of replications used to estimate the null when using the
  114. permutation test. The default is ``1000``.
  115. workers : int or map-like callable, optional
  116. If ``workers`` is an int the population is subdivided into ``workers``
  117. sections and evaluated in parallel (uses ``multiprocessing.Pool
  118. <multiprocessing>``). Supply ``-1`` to use all cores available to the
  119. Process. Alternatively supply a map-like callable, such as
  120. ``multiprocessing.Pool.map`` for evaluating the p-value in parallel.
  121. This evaluation is carried out as ``workers(func, iterable)``.
  122. Requires that `func` be pickleable. The default is ``1``.
  123. is_twosamp : bool, optional
  124. If `True`, a two sample test will be run. If ``x`` and ``y`` have
  125. shapes ``(n, p)`` and ``(m, p)``, this optional will be overridden and
  126. set to ``True``. Set to ``True`` if ``x`` and ``y`` both have shapes
  127. ``(n, p)`` and a two sample test is desired. The default is ``False``.
  128. Note that this will not run if inputs are distance matrices.
  129. random_state : {None, int, `numpy.random.Generator`,
  130. `numpy.random.RandomState`}, optional
  131. If `seed` is None (or `np.random`), the `numpy.random.RandomState`
  132. singleton is used.
  133. If `seed` is an int, a new ``RandomState`` instance is used,
  134. seeded with `seed`.
  135. If `seed` is already a ``Generator`` or ``RandomState`` instance then
  136. that instance is used.
  137. Returns
  138. -------
  139. res : MGCResult
  140. An object containing attributes:
  141. statistic : float
  142. The sample MGC test statistic within ``[-1, 1]``.
  143. pvalue : float
  144. The p-value obtained via permutation.
  145. mgc_dict : dict
  146. Contains additional useful results:
  147. - mgc_map : ndarray
  148. A 2D representation of the latent geometry of the
  149. relationship.
  150. - opt_scale : (int, int)
  151. The estimated optimal scale as a ``(x, y)`` pair.
  152. - null_dist : list
  153. The null distribution derived from the permuted matrices.
  154. See Also
  155. --------
  156. pearsonr : Pearson correlation coefficient and p-value for testing
  157. non-correlation.
  158. kendalltau : Calculates Kendall's tau.
  159. spearmanr : Calculates a Spearman rank-order correlation coefficient.
  160. Notes
  161. -----
  162. A description of the process of MGC and applications on neuroscience data
  163. can be found in [1]_. It is performed using the following steps:
  164. #. Two distance matrices :math:`D^X` and :math:`D^Y` are computed and
  165. modified to be mean zero columnwise. This results in two
  166. :math:`n \times n` distance matrices :math:`A` and :math:`B` (the
  167. centering and unbiased modification) [3]_.
  168. #. For all values :math:`k` and :math:`l` from :math:`1, ..., n`,
  169. * The :math:`k`-nearest neighbor and :math:`l`-nearest neighbor graphs
  170. are calculated for each property. Here, :math:`G_k (i, j)` indicates
  171. the :math:`k`-smallest values of the :math:`i`-th row of :math:`A`
  172. and :math:`H_l (i, j)` indicates the :math:`l` smallested values of
  173. the :math:`i`-th row of :math:`B`
  174. * Let :math:`\circ` denotes the entry-wise matrix product, then local
  175. correlations are summed and normalized using the following statistic:
  176. .. math::
  177. c^{kl} = \frac{\sum_{ij} A G_k B H_l}
  178. {\sqrt{\sum_{ij} A^2 G_k \times \sum_{ij} B^2 H_l}}
  179. #. The MGC test statistic is the smoothed optimal local correlation of
  180. :math:`\{ c^{kl} \}`. Denote the smoothing operation as :math:`R(\cdot)`
  181. (which essentially set all isolated large correlations) as 0 and
  182. connected large correlations the same as before, see [3]_.) MGC is,
  183. .. math::
  184. MGC_n (x, y) = \max_{(k, l)} R \left(c^{kl} \left( x_n, y_n \right)
  185. \right)
  186. The test statistic returns a value between :math:`(-1, 1)` since it is
  187. normalized.
  188. The p-value returned is calculated using a permutation test. This process
  189. is completed by first randomly permuting :math:`y` to estimate the null
  190. distribution and then calculating the probability of observing a test
  191. statistic, under the null, at least as extreme as the observed test
  192. statistic.
  193. MGC requires at least 5 samples to run with reliable results. It can also
  194. handle high-dimensional data sets.
  195. In addition, by manipulating the input data matrices, the two-sample
  196. testing problem can be reduced to the independence testing problem [4]_.
  197. Given sample data :math:`U` and :math:`V` of sizes :math:`p \times n`
  198. :math:`p \times m`, data matrix :math:`X` and :math:`Y` can be created as
  199. follows:
  200. .. math::
  201. X = [U | V] \in \mathcal{R}^{p \times (n + m)}
  202. Y = [0_{1 \times n} | 1_{1 \times m}] \in \mathcal{R}^{(n + m)}
  203. Then, the MGC statistic can be calculated as normal. This methodology can
  204. be extended to similar tests such as distance correlation [4]_.
  205. .. versionadded:: 1.4.0
  206. References
  207. ----------
  208. .. [1] Vogelstein, J. T., Bridgeford, E. W., Wang, Q., Priebe, C. E.,
  209. Maggioni, M., & Shen, C. (2019). Discovering and deciphering
  210. relationships across disparate data modalities. ELife.
  211. .. [2] Panda, S., Palaniappan, S., Xiong, J., Swaminathan, A.,
  212. Ramachandran, S., Bridgeford, E. W., ... Vogelstein, J. T. (2019).
  213. mgcpy: A Comprehensive High Dimensional Independence Testing Python
  214. Package. :arXiv:`1907.02088`
  215. .. [3] Shen, C., Priebe, C.E., & Vogelstein, J. T. (2019). From distance
  216. correlation to multiscale graph correlation. Journal of the American
  217. Statistical Association.
  218. .. [4] Shen, C. & Vogelstein, J. T. (2018). The Exact Equivalence of
  219. Distance and Kernel Methods for Hypothesis Testing.
  220. :arXiv:`1806.05514`
  221. Examples
  222. --------
  223. >>> import numpy as np
  224. >>> from scipy.stats import multiscale_graphcorr
  225. >>> x = np.arange(100)
  226. >>> y = x
  227. >>> res = multiscale_graphcorr(x, y)
  228. >>> res.statistic, res.pvalue
  229. (1.0, 0.001)
  230. To run an unpaired two-sample test,
  231. >>> x = np.arange(100)
  232. >>> y = np.arange(79)
  233. >>> res = multiscale_graphcorr(x, y)
  234. >>> res.statistic, res.pvalue # doctest: +SKIP
  235. (0.033258146255703246, 0.023)
  236. or, if shape of the inputs are the same,
  237. >>> x = np.arange(100)
  238. >>> y = x
  239. >>> res = multiscale_graphcorr(x, y, is_twosamp=True)
  240. >>> res.statistic, res.pvalue # doctest: +SKIP
  241. (-0.008021809890200488, 1.0)
  242. """
  243. if not isinstance(x, np.ndarray) or not isinstance(y, np.ndarray):
  244. raise ValueError("x and y must be ndarrays")
  245. # convert arrays of type (n,) to (n, 1)
  246. if x.ndim == 1:
  247. x = x[:, np.newaxis]
  248. elif x.ndim != 2:
  249. raise ValueError(f"Expected a 2-D array `x`, found shape {x.shape}")
  250. if y.ndim == 1:
  251. y = y[:, np.newaxis]
  252. elif y.ndim != 2:
  253. raise ValueError(f"Expected a 2-D array `y`, found shape {y.shape}")
  254. nx, px = x.shape
  255. ny, py = y.shape
  256. # check for NaNs
  257. _contains_nan(x, nan_policy='raise')
  258. _contains_nan(y, nan_policy='raise')
  259. # check for positive or negative infinity and raise error
  260. if np.sum(np.isinf(x)) > 0 or np.sum(np.isinf(y)) > 0:
  261. raise ValueError("Inputs contain infinities")
  262. if nx != ny:
  263. if px == py:
  264. # reshape x and y for two sample testing
  265. is_twosamp = True
  266. else:
  267. raise ValueError("Shape mismatch, x and y must have shape [n, p] "
  268. "and [n, q] or have shape [n, p] and [m, p].")
  269. if nx < 5 or ny < 5:
  270. raise ValueError("MGC requires at least 5 samples to give reasonable "
  271. "results.")
  272. # convert x and y to float
  273. x = x.astype(np.float64)
  274. y = y.astype(np.float64)
  275. # check if compute_distance_matrix if a callable()
  276. if not callable(compute_distance) and compute_distance is not None:
  277. raise ValueError("Compute_distance must be a function.")
  278. # check if number of reps exists, integer, or > 0 (if under 1000 raises
  279. # warning)
  280. if not isinstance(reps, int) or reps < 0:
  281. raise ValueError("Number of reps must be an integer greater than 0.")
  282. elif reps < 1000:
  283. msg = ("The number of replications is low (under 1000), and p-value "
  284. "calculations may be unreliable. Use the p-value result, with "
  285. "caution!")
  286. warnings.warn(msg, RuntimeWarning, stacklevel=2)
  287. if is_twosamp:
  288. if compute_distance is None:
  289. raise ValueError("Cannot run if inputs are distance matrices")
  290. x, y = _two_sample_transform(x, y)
  291. if compute_distance is not None:
  292. # compute distance matrices for x and y
  293. x = compute_distance(x)
  294. y = compute_distance(y)
  295. # calculate MGC stat
  296. stat, stat_dict = _mgc_stat(x, y)
  297. stat_mgc_map = stat_dict["stat_mgc_map"]
  298. opt_scale = stat_dict["opt_scale"]
  299. # calculate permutation MGC p-value
  300. pvalue, null_dist = _perm_test(x, y, stat, reps=reps, workers=workers,
  301. random_state=random_state)
  302. # save all stats (other than stat/p-value) in dictionary
  303. mgc_dict = {"mgc_map": stat_mgc_map,
  304. "opt_scale": opt_scale,
  305. "null_dist": null_dist}
  306. # create result object with alias for backward compatibility
  307. res = MGCResult(stat, pvalue, mgc_dict)
  308. res.stat = stat
  309. return res
  310. def _mgc_stat(distx, disty):
  311. r"""Helper function that calculates the MGC stat. See above for use.
  312. Parameters
  313. ----------
  314. distx, disty : ndarray
  315. `distx` and `disty` have shapes ``(n, p)`` and ``(n, q)`` or
  316. ``(n, n)`` and ``(n, n)``
  317. if distance matrices.
  318. Returns
  319. -------
  320. stat : float
  321. The sample MGC test statistic within ``[-1, 1]``.
  322. stat_dict : dict
  323. Contains additional useful additional returns containing the following
  324. keys:
  325. - stat_mgc_map : ndarray
  326. MGC-map of the statistics.
  327. - opt_scale : (float, float)
  328. The estimated optimal scale as a ``(x, y)`` pair.
  329. """
  330. # calculate MGC map and optimal scale
  331. stat_mgc_map = _local_correlations(distx, disty, global_corr='mgc')
  332. n, m = stat_mgc_map.shape
  333. if m == 1 or n == 1:
  334. # the global scale at is the statistic calculated at maximal nearest
  335. # neighbors. There is not enough local scale to search over, so
  336. # default to global scale
  337. stat = stat_mgc_map[m - 1][n - 1]
  338. opt_scale = m * n
  339. else:
  340. samp_size = len(distx) - 1
  341. # threshold to find connected region of significant local correlations
  342. sig_connect = _threshold_mgc_map(stat_mgc_map, samp_size)
  343. # maximum within the significant region
  344. stat, opt_scale = _smooth_mgc_map(sig_connect, stat_mgc_map)
  345. stat_dict = {"stat_mgc_map": stat_mgc_map,
  346. "opt_scale": opt_scale}
  347. return stat, stat_dict
  348. def _threshold_mgc_map(stat_mgc_map, samp_size):
  349. r"""
  350. Finds a connected region of significance in the MGC-map by thresholding.
  351. Parameters
  352. ----------
  353. stat_mgc_map : ndarray
  354. All local correlations within ``[-1,1]``.
  355. samp_size : int
  356. The sample size of original data.
  357. Returns
  358. -------
  359. sig_connect : ndarray
  360. A binary matrix with 1's indicating the significant region.
  361. """
  362. m, n = stat_mgc_map.shape
  363. # 0.02 is simply an empirical threshold, this can be set to 0.01 or 0.05
  364. # with varying levels of performance. Threshold is based on a beta
  365. # approximation.
  366. per_sig = 1 - (0.02 / samp_size) # Percentile to consider as significant
  367. threshold = samp_size * (samp_size - 3)/4 - 1/2 # Beta approximation
  368. threshold = distributions.beta.ppf(per_sig, threshold, threshold) * 2 - 1
  369. # the global scale at is the statistic calculated at maximal nearest
  370. # neighbors. Threshold is the maximum on the global and local scales
  371. threshold = max(threshold, stat_mgc_map[m - 1][n - 1])
  372. # find the largest connected component of significant correlations
  373. sig_connect = stat_mgc_map > threshold
  374. if np.sum(sig_connect) > 0:
  375. sig_connect, _ = _measurements.label(sig_connect)
  376. _, label_counts = np.unique(sig_connect, return_counts=True)
  377. # skip the first element in label_counts, as it is count(zeros)
  378. max_label = np.argmax(label_counts[1:]) + 1
  379. sig_connect = sig_connect == max_label
  380. else:
  381. sig_connect = np.array([[False]])
  382. return sig_connect
  383. def _smooth_mgc_map(sig_connect, stat_mgc_map):
  384. """Finds the smoothed maximal within the significant region R.
  385. If area of R is too small it returns the last local correlation. Otherwise,
  386. returns the maximum within significant_connected_region.
  387. Parameters
  388. ----------
  389. sig_connect : ndarray
  390. A binary matrix with 1's indicating the significant region.
  391. stat_mgc_map : ndarray
  392. All local correlations within ``[-1, 1]``.
  393. Returns
  394. -------
  395. stat : float
  396. The sample MGC statistic within ``[-1, 1]``.
  397. opt_scale: (float, float)
  398. The estimated optimal scale as an ``(x, y)`` pair.
  399. """
  400. m, n = stat_mgc_map.shape
  401. # the global scale at is the statistic calculated at maximal nearest
  402. # neighbors. By default, statistic and optimal scale are global.
  403. stat = stat_mgc_map[m - 1][n - 1]
  404. opt_scale = [m, n]
  405. if np.linalg.norm(sig_connect) != 0:
  406. # proceed only when the connected region's area is sufficiently large
  407. # 0.02 is simply an empirical threshold, this can be set to 0.01 or 0.05
  408. # with varying levels of performance
  409. if np.sum(sig_connect) >= np.ceil(0.02 * max(m, n)) * min(m, n):
  410. max_corr = max(stat_mgc_map[sig_connect])
  411. # find all scales within significant_connected_region that maximize
  412. # the local correlation
  413. max_corr_index = np.where((stat_mgc_map >= max_corr) & sig_connect)
  414. if max_corr >= stat:
  415. stat = max_corr
  416. k, l = max_corr_index
  417. one_d_indices = k * n + l # 2D to 1D indexing
  418. k = np.max(one_d_indices) // n
  419. l = np.max(one_d_indices) % n
  420. opt_scale = [k+1, l+1] # adding 1s to match R indexing
  421. return stat, opt_scale
  422. def _two_sample_transform(u, v):
  423. """Helper function that concatenates x and y for two sample MGC stat.
  424. See above for use.
  425. Parameters
  426. ----------
  427. u, v : ndarray
  428. `u` and `v` have shapes ``(n, p)`` and ``(m, p)``.
  429. Returns
  430. -------
  431. x : ndarray
  432. Concatenate `u` and `v` along the ``axis = 0``. `x` thus has shape
  433. ``(2n, p)``.
  434. y : ndarray
  435. Label matrix for `x` where 0 refers to samples that comes from `u` and
  436. 1 refers to samples that come from `v`. `y` thus has shape ``(2n, 1)``.
  437. """
  438. nx = u.shape[0]
  439. ny = v.shape[0]
  440. x = np.concatenate([u, v], axis=0)
  441. y = np.concatenate([np.zeros(nx), np.ones(ny)], axis=0).reshape(-1, 1)
  442. return x, y