vq.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832
  1. """
  2. K-means clustering and vector quantization (:mod:`scipy.cluster.vq`)
  3. ====================================================================
  4. Provides routines for k-means clustering, generating code books
  5. from k-means models and quantizing vectors by comparing them with
  6. centroids in a code book.
  7. .. autosummary::
  8. :toctree: generated/
  9. whiten -- Normalize a group of observations so each feature has unit variance
  10. vq -- Calculate code book membership of a set of observation vectors
  11. kmeans -- Perform k-means on a set of observation vectors forming k clusters
  12. kmeans2 -- A different implementation of k-means with more methods
  13. -- for initializing centroids
  14. Background information
  15. ----------------------
  16. The k-means algorithm takes as input the number of clusters to
  17. generate, k, and a set of observation vectors to cluster. It
  18. returns a set of centroids, one for each of the k clusters. An
  19. observation vector is classified with the cluster number or
  20. centroid index of the centroid closest to it.
  21. A vector v belongs to cluster i if it is closer to centroid i than
  22. any other centroid. If v belongs to i, we say centroid i is the
  23. dominating centroid of v. The k-means algorithm tries to
  24. minimize distortion, which is defined as the sum of the squared distances
  25. between each observation vector and its dominating centroid.
  26. The minimization is achieved by iteratively reclassifying
  27. the observations into clusters and recalculating the centroids until
  28. a configuration is reached in which the centroids are stable. One can
  29. also define a maximum number of iterations.
  30. Since vector quantization is a natural application for k-means,
  31. information theory terminology is often used. The centroid index
  32. or cluster index is also referred to as a "code" and the table
  33. mapping codes to centroids and, vice versa, is often referred to as a
  34. "code book". The result of k-means, a set of centroids, can be
  35. used to quantize vectors. Quantization aims to find an encoding of
  36. vectors that reduces the expected distortion.
  37. All routines expect obs to be an M by N array, where the rows are
  38. the observation vectors. The codebook is a k by N array, where the
  39. ith row is the centroid of code word i. The observation vectors
  40. and centroids have the same feature dimension.
  41. As an example, suppose we wish to compress a 24-bit color image
  42. (each pixel is represented by one byte for red, one for blue, and
  43. one for green) before sending it over the web. By using a smaller
  44. 8-bit encoding, we can reduce the amount of data by two
  45. thirds. Ideally, the colors for each of the 256 possible 8-bit
  46. encoding values should be chosen to minimize distortion of the
  47. color. Running k-means with k=256 generates a code book of 256
  48. codes, which fills up all possible 8-bit sequences. Instead of
  49. sending a 3-byte value for each pixel, the 8-bit centroid index
  50. (or code word) of the dominating centroid is transmitted. The code
  51. book is also sent over the wire so each 8-bit code can be
  52. translated back to a 24-bit pixel value representation. If the
  53. image of interest was of an ocean, we would expect many 24-bit
  54. blues to be represented by 8-bit codes. If it was an image of a
  55. human face, more flesh-tone colors would be represented in the
  56. code book.
  57. """
  58. import warnings
  59. import numpy as np
  60. from collections import deque
  61. from scipy._lib._array_api import (_asarray, array_namespace, is_lazy_array,
  62. xp_capabilities, xp_copy, xp_size)
  63. from scipy._lib._util import (check_random_state, rng_integers,
  64. _transition_to_rng)
  65. from scipy._lib import array_api_extra as xpx
  66. from scipy.spatial.distance import cdist
  67. from . import _vq
# Docstrings in this module are written in reStructuredText.
__docformat__ = 'restructuredtext'
# Names exported by a star-import of this module.
__all__ = ['whiten', 'vq', 'kmeans', 'kmeans2']
  70. class ClusterError(Exception):
  71. pass
  72. @xp_capabilities()
  73. def whiten(obs, check_finite=None):
  74. """
  75. Normalize a group of observations on a per feature basis.
  76. Before running k-means, it is beneficial to rescale each feature
  77. dimension of the observation set by its standard deviation (i.e. "whiten"
  78. it - as in "white noise" where each frequency has equal power).
  79. Each feature is divided by its standard deviation across all observations
  80. to give it unit variance.
  81. Parameters
  82. ----------
  83. obs : ndarray
  84. Each row of the array is an observation. The
  85. columns are the features seen during each observation::
  86. # f0 f1 f2
  87. obs = [[ 1., 1., 1.], #o0
  88. [ 2., 2., 2.], #o1
  89. [ 3., 3., 3.], #o2
  90. [ 4., 4., 4.]] #o3
  91. check_finite : bool, optional
  92. Whether to check that the input matrices contain only finite numbers.
  93. Disabling may give a performance gain, but may result in problems
  94. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  95. Default: True for eager backends and False for lazy ones.
  96. Returns
  97. -------
  98. result : ndarray
  99. Contains the values in `obs` scaled by the standard deviation
  100. of each column.
  101. Examples
  102. --------
  103. >>> import numpy as np
  104. >>> from scipy.cluster.vq import whiten
  105. >>> features = np.array([[1.9, 2.3, 1.7],
  106. ... [1.5, 2.5, 2.2],
  107. ... [0.8, 0.6, 1.7,]])
  108. >>> whiten(features)
  109. array([[ 4.17944278, 2.69811351, 7.21248917],
  110. [ 3.29956009, 2.93273208, 9.33380951],
  111. [ 1.75976538, 0.7038557 , 7.21248917]])
  112. """
  113. xp = array_namespace(obs)
  114. if check_finite is None:
  115. check_finite = not is_lazy_array(obs)
  116. obs = _asarray(obs, check_finite=check_finite, xp=xp)
  117. std_dev = xp.std(obs, axis=0)
  118. zero_std_mask = std_dev == 0
  119. std_dev = xpx.at(std_dev, zero_std_mask).set(1.0)
  120. if check_finite and xp.any(zero_std_mask):
  121. warnings.warn("Some columns have standard deviation zero. "
  122. "The values of these columns will not change.",
  123. RuntimeWarning, stacklevel=2)
  124. return obs / std_dev
  125. @xp_capabilities(cpu_only=True, reason="uses spatial.distance.cdist",
  126. jax_jit=False, allow_dask_compute=True)
  127. def vq(obs, code_book, check_finite=True):
  128. """
  129. Assign codes from a code book to observations.
  130. Assigns a code from a code book to each observation. Each
  131. observation vector in the 'M' by 'N' `obs` array is compared with the
  132. centroids in the code book and assigned the code of the closest
  133. centroid.
  134. The features in `obs` should have unit variance, which can be
  135. achieved by passing them through the whiten function. The code
  136. book can be created with the k-means algorithm or a different
  137. encoding algorithm.
  138. Parameters
  139. ----------
  140. obs : ndarray
  141. Each row of the 'M' x 'N' array is an observation. The columns are
  142. the "features" seen during each observation. The features must be
  143. whitened first using the whiten function or something equivalent.
  144. code_book : ndarray
  145. The code book is usually generated using the k-means algorithm.
  146. Each row of the array holds a different code, and the columns are
  147. the features of the code::
  148. # f0 f1 f2 f3
  149. code_book = [[ 1., 2., 3., 4.], #c0
  150. [ 1., 2., 3., 4.], #c1
  151. [ 1., 2., 3., 4.]] #c2
  152. check_finite : bool, optional
  153. Whether to check that the input matrices contain only finite numbers.
  154. Disabling may give a performance gain, but may result in problems
  155. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  156. Default: True
  157. Returns
  158. -------
  159. code : ndarray
  160. A length M array holding the code book index for each observation.
  161. dist : ndarray
  162. The distortion (distance) between the observation and its nearest
  163. code.
  164. Examples
  165. --------
  166. >>> import numpy as np
  167. >>> from scipy.cluster.vq import vq
  168. >>> code_book = np.array([[1., 1., 1.],
  169. ... [2., 2., 2.]])
  170. >>> features = np.array([[1.9, 2.3, 1.7],
  171. ... [1.5, 2.5, 2.2],
  172. ... [0.8, 0.6, 1.7]])
  173. >>> vq(features, code_book)
  174. (array([1, 1, 0], dtype=int32), array([0.43588989, 0.73484692, 0.83066239]))
  175. """
  176. xp = array_namespace(obs, code_book)
  177. obs = _asarray(obs, xp=xp, check_finite=check_finite)
  178. code_book = _asarray(code_book, xp=xp, check_finite=check_finite)
  179. ct = xp.result_type(obs, code_book)
  180. if xp.isdtype(ct, kind='real floating'):
  181. c_obs = xp.astype(obs, ct, copy=False)
  182. c_code_book = xp.astype(code_book, ct, copy=False)
  183. c_obs = np.asarray(c_obs)
  184. c_code_book = np.asarray(c_code_book)
  185. result = _vq.vq(c_obs, c_code_book)
  186. return xp.asarray(result[0]), xp.asarray(result[1])
  187. return py_vq(obs, code_book, check_finite=False)
  188. def py_vq(obs, code_book, check_finite=True):
  189. """ Python version of vq algorithm.
  190. The algorithm computes the Euclidean distance between each
  191. observation and every frame in the code_book.
  192. Parameters
  193. ----------
  194. obs : ndarray
  195. Expects a rank 2 array. Each row is one observation.
  196. code_book : ndarray
  197. Code book to use. Same format than obs. Should have same number of
  198. features (e.g., columns) than obs.
  199. check_finite : bool, optional
  200. Whether to check that the input matrices contain only finite numbers.
  201. Disabling may give a performance gain, but may result in problems
  202. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  203. Default: True
  204. Returns
  205. -------
  206. code : ndarray
  207. code[i] gives the label of the ith obversation; its code is
  208. code_book[code[i]].
  209. mind_dist : ndarray
  210. min_dist[i] gives the distance between the ith observation and its
  211. corresponding code.
  212. Notes
  213. -----
  214. This function is slower than the C version but works for
  215. all input types. If the inputs have the wrong types for the
  216. C versions of the function, this one is called as a last resort.
  217. It is about 20 times slower than the C version.
  218. """
  219. xp = array_namespace(obs, code_book)
  220. obs = _asarray(obs, xp=xp, check_finite=check_finite)
  221. code_book = _asarray(code_book, xp=xp, check_finite=check_finite)
  222. if obs.ndim != code_book.ndim:
  223. raise ValueError("Observation and code_book should have the same rank")
  224. if obs.ndim == 1:
  225. obs = obs[:, xp.newaxis]
  226. code_book = code_book[:, xp.newaxis]
  227. # Once `cdist` has array API support, this `xp.asarray` call can be removed
  228. dist = xp.asarray(cdist(obs, code_book))
  229. code = xp.argmin(dist, axis=1)
  230. min_dist = xp.min(dist, axis=1)
  231. return code, min_dist
  232. def _kmeans(obs, guess, thresh=1e-5, xp=None):
  233. """ "raw" version of k-means.
  234. Returns
  235. -------
  236. code_book
  237. The lowest distortion codebook found.
  238. avg_dist
  239. The average distance a observation is from a code in the book.
  240. Lower means the code_book matches the data better.
  241. See Also
  242. --------
  243. kmeans : wrapper around k-means
  244. Examples
  245. --------
  246. Note: not whitened in this example.
  247. >>> import numpy as np
  248. >>> from scipy.cluster.vq import _kmeans
  249. >>> features = np.array([[ 1.9,2.3],
  250. ... [ 1.5,2.5],
  251. ... [ 0.8,0.6],
  252. ... [ 0.4,1.8],
  253. ... [ 1.0,1.0]])
  254. >>> book = np.array((features[0],features[2]))
  255. >>> _kmeans(features,book)
  256. (array([[ 1.7 , 2.4 ],
  257. [ 0.73333333, 1.13333333]]), 0.40563916697728591)
  258. """
  259. xp = np if xp is None else xp
  260. code_book = guess
  261. diff = xp.inf
  262. prev_avg_dists = deque([diff], maxlen=2)
  263. np_obs = np.asarray(obs)
  264. while diff > thresh:
  265. # compute membership and distances between obs and code_book
  266. obs_code, distort = vq(obs, code_book, check_finite=False)
  267. prev_avg_dists.append(xp.mean(distort, axis=-1))
  268. # recalc code_book as centroids of associated obs
  269. obs_code = np.asarray(obs_code)
  270. code_book, has_members = _vq.update_cluster_means(np_obs, obs_code,
  271. code_book.shape[0])
  272. code_book = code_book[has_members]
  273. code_book = xp.asarray(code_book)
  274. diff = xp.abs(prev_avg_dists[0] - prev_avg_dists[1])
  275. return code_book, prev_avg_dists[1]
  276. @xp_capabilities(cpu_only=True, jax_jit=False, allow_dask_compute=True)
  277. @_transition_to_rng("seed")
  278. def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True,
  279. *, rng=None):
  280. """
  281. Performs k-means on a set of observation vectors forming k clusters.
  282. The k-means algorithm adjusts the classification of the observations
  283. into clusters and updates the cluster centroids until the position of
  284. the centroids is stable over successive iterations. In this
  285. implementation of the algorithm, the stability of the centroids is
  286. determined by comparing the absolute value of the change in the average
  287. Euclidean distance between the observations and their corresponding
  288. centroids against a threshold. This yields
  289. a code book mapping centroids to codes and vice versa.
  290. Parameters
  291. ----------
  292. obs : ndarray
  293. Each row of the M by N array is an observation vector. The
  294. columns are the features seen during each observation.
  295. The features must be whitened first with the `whiten` function.
  296. k_or_guess : int or ndarray
  297. The number of centroids to generate. A code is assigned to
  298. each centroid, which is also the row index of the centroid
  299. in the code_book matrix generated.
  300. The initial k centroids are chosen by randomly selecting
  301. observations from the observation matrix. Alternatively,
  302. passing a k by N array specifies the initial k centroids.
  303. iter : int, optional
  304. The number of times to run k-means, returning the codebook
  305. with the lowest distortion. This argument is ignored if
  306. initial centroids are specified with an array for the
  307. ``k_or_guess`` parameter. This parameter does not represent the
  308. number of iterations of the k-means algorithm.
  309. thresh : float, optional
  310. Terminates the k-means algorithm if the change in
  311. distortion since the last k-means iteration is less than
  312. or equal to threshold.
  313. check_finite : bool, optional
  314. Whether to check that the input matrices contain only finite numbers.
  315. Disabling may give a performance gain, but may result in problems
  316. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  317. Default: True
  318. rng : `numpy.random.Generator`, optional
  319. Pseudorandom number generator state. When `rng` is None, a new
  320. `numpy.random.Generator` is created using entropy from the
  321. operating system. Types other than `numpy.random.Generator` are
  322. passed to `numpy.random.default_rng` to instantiate a ``Generator``.
  323. Returns
  324. -------
  325. codebook : ndarray
  326. A k by N array of k centroids. The ith centroid
  327. codebook[i] is represented with the code i. The centroids
  328. and codes generated represent the lowest distortion seen,
  329. not necessarily the globally minimal distortion.
  330. Note that the number of centroids is not necessarily the same as the
  331. ``k_or_guess`` parameter, because centroids assigned to no observations
  332. are removed during iterations.
  333. distortion : float
  334. The mean (non-squared) Euclidean distance between the observations
  335. passed and the centroids generated. Note the difference to the standard
  336. definition of distortion in the context of the k-means algorithm, which
  337. is the sum of the squared distances.
  338. See Also
  339. --------
  340. kmeans2 : a different implementation of k-means clustering
  341. with more methods for generating initial centroids but without
  342. using a distortion change threshold as a stopping criterion.
  343. whiten : must be called prior to passing an observation matrix
  344. to kmeans.
  345. Notes
  346. -----
  347. For more functionalities or optimal performance, you can use
  348. `sklearn.cluster.KMeans <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html>`_.
  349. `This <https://hdbscan.readthedocs.io/en/latest/performance_and_scalability.html#comparison-of-high-performance-implementations>`_
  350. is a benchmark result of several implementations.
  351. Examples
  352. --------
  353. >>> import numpy as np
  354. >>> from scipy.cluster.vq import vq, kmeans, whiten
  355. >>> import matplotlib.pyplot as plt
  356. >>> features = np.array([[ 1.9,2.3],
  357. ... [ 1.5,2.5],
  358. ... [ 0.8,0.6],
  359. ... [ 0.4,1.8],
  360. ... [ 0.1,0.1],
  361. ... [ 0.2,1.8],
  362. ... [ 2.0,0.5],
  363. ... [ 0.3,1.5],
  364. ... [ 1.0,1.0]])
  365. >>> whitened = whiten(features)
  366. >>> book = np.array((whitened[0],whitened[2]))
  367. >>> kmeans(whitened,book)
  368. (array([[ 2.3110306 , 2.86287398], # random
  369. [ 0.93218041, 1.24398691]]), 0.85684700941625547)
  370. >>> codes = 3
  371. >>> kmeans(whitened,codes)
  372. (array([[ 2.3110306 , 2.86287398], # random
  373. [ 1.32544402, 0.65607529],
  374. [ 0.40782893, 2.02786907]]), 0.5196582527686241)
  375. >>> # Create 50 datapoints in two clusters a and b
  376. >>> pts = 50
  377. >>> rng = np.random.default_rng()
  378. >>> a = rng.multivariate_normal([0, 0], [[4, 1], [1, 4]], size=pts)
  379. >>> b = rng.multivariate_normal([30, 10],
  380. ... [[10, 2], [2, 1]],
  381. ... size=pts)
  382. >>> features = np.concatenate((a, b))
  383. >>> # Whiten data
  384. >>> whitened = whiten(features)
  385. >>> # Find 2 clusters in the data
  386. >>> codebook, distortion = kmeans(whitened, 2)
  387. >>> # Plot whitened data and cluster centers in red
  388. >>> plt.scatter(whitened[:, 0], whitened[:, 1])
  389. >>> plt.scatter(codebook[:, 0], codebook[:, 1], c='r')
  390. >>> plt.show()
  391. """
  392. if isinstance(k_or_guess, int):
  393. xp = array_namespace(obs)
  394. else:
  395. xp = array_namespace(obs, k_or_guess)
  396. obs = _asarray(obs, xp=xp, check_finite=check_finite)
  397. guess = _asarray(k_or_guess, xp=xp, check_finite=check_finite)
  398. if iter < 1:
  399. raise ValueError(f"iter must be at least 1, got {iter}")
  400. # Determine whether a count (scalar) or an initial guess (array) was passed.
  401. if xp_size(guess) != 1:
  402. if xp_size(guess) < 1:
  403. raise ValueError(f"Asked for 0 clusters. Initial book was {guess}")
  404. return _kmeans(obs, guess, thresh=thresh, xp=xp)
  405. # k_or_guess is a scalar, now verify that it's an integer
  406. k = int(guess)
  407. if k != guess:
  408. raise ValueError("If k_or_guess is a scalar, it must be an integer.")
  409. if k < 1:
  410. raise ValueError(f"Asked for {k} clusters.")
  411. rng = check_random_state(rng)
  412. # initialize best distance value to a large value
  413. best_dist = xp.inf
  414. for i in range(iter):
  415. # the initial code book is randomly selected from observations
  416. guess = _kpoints(obs, k, rng, xp)
  417. book, dist = _kmeans(obs, guess, thresh=thresh, xp=xp)
  418. if dist < best_dist:
  419. best_book = book
  420. best_dist = dist
  421. return best_book, best_dist
  422. def _kpoints(data, k, rng, xp):
  423. """Pick k points at random in data (one row = one observation).
  424. Parameters
  425. ----------
  426. data : ndarray
  427. Expect a rank 1 or 2 array. Rank 1 are assumed to describe one
  428. dimensional data, rank 2 multidimensional data, in which case one
  429. row is one observation.
  430. k : int
  431. Number of samples to generate.
  432. rng : `numpy.random.Generator` or `numpy.random.RandomState`
  433. Random number generator.
  434. Returns
  435. -------
  436. x : ndarray
  437. A 'k' by 'N' containing the initial centroids
  438. """
  439. idx = rng.choice(data.shape[0], size=int(k), replace=False)
  440. # convert to array with default integer dtype (avoids numpy#25607)
  441. idx = xp.asarray(idx, dtype=xp.asarray([1]).dtype)
  442. return xp.take(data, idx, axis=0)
  443. def _krandinit(data, k, rng, xp):
  444. """Returns k samples of a random variable whose parameters depend on data.
  445. More precisely, it returns k observations sampled from a Gaussian random
  446. variable whose mean and covariances are the ones estimated from the data.
  447. Parameters
  448. ----------
  449. data : ndarray
  450. Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
  451. data, rank 2 multidimensional data, in which case one
  452. row is one observation.
  453. k : int
  454. Number of samples to generate.
  455. rng : `numpy.random.Generator` or `numpy.random.RandomState`
  456. Random number generator.
  457. Returns
  458. -------
  459. x : ndarray
  460. A 'k' by 'N' containing the initial centroids
  461. """
  462. mu = xp.mean(data, axis=0)
  463. k = np.asarray(k)
  464. if data.ndim == 1:
  465. _cov = xpx.cov(data, xp=xp)
  466. x = rng.standard_normal(size=k)
  467. x = xp.asarray(x)
  468. x *= xp.sqrt(_cov)
  469. elif data.shape[1] > data.shape[0]:
  470. # initialize when the covariance matrix is rank deficient
  471. _, s, vh = xp.linalg.svd(data - mu, full_matrices=False)
  472. x = rng.standard_normal(size=(k, xp_size(s)))
  473. x = xp.asarray(x)
  474. sVh = s[:, None] * vh / xp.sqrt(data.shape[0] - xp.asarray(1.))
  475. x = x @ sVh
  476. else:
  477. _cov = xpx.atleast_nd(xpx.cov(data.T, xp=xp), ndim=2, xp=xp)
  478. # k rows, d cols (one row = one obs)
  479. # Generate k sample of a random variable ~ Gaussian(mu, cov)
  480. x = rng.standard_normal(size=(k, xp_size(mu)))
  481. x = xp.asarray(x)
  482. x = x @ xp.linalg.cholesky(_cov).T
  483. x += mu
  484. return x
  485. def _kpp(data, k, rng, xp):
  486. """ Picks k points in the data based on the kmeans++ method.
  487. Parameters
  488. ----------
  489. data : ndarray
  490. Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
  491. data, rank 2 multidimensional data, in which case one
  492. row is one observation.
  493. k : int
  494. Number of samples to generate.
  495. rng : `numpy.random.Generator` or `numpy.random.RandomState`
  496. Random number generator.
  497. Returns
  498. -------
  499. init : ndarray
  500. A 'k' by 'N' containing the initial centroids.
  501. References
  502. ----------
  503. .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
  504. careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
  505. on Discrete Algorithms, 2007.
  506. """
  507. ndim = len(data.shape)
  508. if ndim == 1:
  509. data = data[:, None]
  510. dims = data.shape[1]
  511. init = xp.empty((int(k), dims))
  512. for i in range(k):
  513. if i == 0:
  514. data_idx = rng_integers(rng, data.shape[0])
  515. else:
  516. D2 = cdist(init[:i,:], data, metric='sqeuclidean').min(axis=0)
  517. probs = D2/D2.sum()
  518. cumprobs = probs.cumsum()
  519. r = rng.uniform()
  520. cumprobs = np.asarray(cumprobs)
  521. data_idx = int(np.searchsorted(cumprobs, r))
  522. init = xpx.at(init)[i, :].set(data[data_idx, :])
  523. if ndim == 1:
  524. init = init[:, 0]
  525. return init
  526. _valid_init_meth = {'random': _krandinit, 'points': _kpoints, '++': _kpp}
  527. def _missing_warn():
  528. """Print a warning when called."""
  529. warnings.warn("One of the clusters is empty. "
  530. "Re-run kmeans with a different initialization.",
  531. stacklevel=3)
  532. def _missing_raise():
  533. """Raise a ClusterError when called."""
  534. raise ClusterError("One of the clusters is empty. "
  535. "Re-run kmeans with a different initialization.")
  536. _valid_miss_meth = {'warn': _missing_warn, 'raise': _missing_raise}
  537. @xp_capabilities(cpu_only=True, jax_jit=False, allow_dask_compute=True)
  538. @_transition_to_rng("seed")
  539. def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
  540. missing='warn', check_finite=True, *, rng=None):
  541. """
  542. Classify a set of observations into k clusters using the k-means algorithm.
  543. The algorithm attempts to minimize the Euclidean distance between
  544. observations and centroids. Several initialization methods are
  545. included.
  546. Parameters
  547. ----------
  548. data : ndarray
  549. A 'M' by 'N' array of 'M' observations in 'N' dimensions or a length
  550. 'M' array of 'M' 1-D observations.
  551. k : int or ndarray
  552. The number of clusters to form as well as the number of
  553. centroids to generate. If `minit` initialization string is
  554. 'matrix', or if a ndarray is given instead, it is
  555. interpreted as initial cluster to use instead.
  556. iter : int, optional
  557. Number of iterations of the k-means algorithm to run. Note
  558. that this differs in meaning from the iters parameter to
  559. the kmeans function.
  560. thresh : float, optional
  561. (not used yet)
  562. minit : str, optional
  563. Method for initialization. Available methods are 'random',
  564. 'points', '++' and 'matrix':
  565. 'random': generate k centroids from a Gaussian with mean and
  566. variance estimated from the data.
  567. 'points': choose k observations (rows) at random from data for
  568. the initial centroids.
  569. '++': choose k observations accordingly to the kmeans++ method
  570. (careful seeding)
  571. 'matrix': interpret the k parameter as a k by M (or length k
  572. array for 1-D data) array of initial centroids.
  573. missing : str, optional
  574. Method to deal with empty clusters. Available methods are
  575. 'warn' and 'raise':
  576. 'warn': give a warning and continue.
  577. 'raise': raise an ClusterError and terminate the algorithm.
  578. check_finite : bool, optional
  579. Whether to check that the input matrices contain only finite numbers.
  580. Disabling may give a performance gain, but may result in problems
  581. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  582. Default: True
  583. rng : `numpy.random.Generator`, optional
  584. Pseudorandom number generator state. When `rng` is None, a new
  585. `numpy.random.Generator` is created using entropy from the
  586. operating system. Types other than `numpy.random.Generator` are
  587. passed to `numpy.random.default_rng` to instantiate a ``Generator``.
  588. Returns
  589. -------
  590. centroid : ndarray
  591. A 'k' by 'N' array of centroids found at the last iteration of
  592. k-means.
  593. label : ndarray
  594. label[i] is the code or index of the centroid the
  595. ith observation is closest to.
  596. See Also
  597. --------
  598. kmeans
  599. References
  600. ----------
  601. .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
  602. careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
  603. on Discrete Algorithms, 2007.
  604. Examples
  605. --------
  606. >>> from scipy.cluster.vq import kmeans2
  607. >>> import matplotlib.pyplot as plt
  608. >>> import numpy as np
  609. Create z, an array with shape (100, 2) containing a mixture of samples
  610. from three multivariate normal distributions.
  611. >>> rng = np.random.default_rng()
  612. >>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45)
  613. >>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30)
  614. >>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25)
  615. >>> z = np.concatenate((a, b, c))
  616. >>> rng.shuffle(z)
  617. Compute three clusters.
  618. >>> centroid, label = kmeans2(z, 3, minit='points')
  619. >>> centroid
  620. array([[ 2.22274463, -0.61666946], # may vary
  621. [ 0.54069047, 5.86541444],
  622. [ 6.73846769, 4.01991898]])
  623. How many points are in each cluster?
  624. >>> counts = np.bincount(label)
  625. >>> counts
  626. array([29, 51, 20]) # may vary
  627. Plot the clusters.
  628. >>> w0 = z[label == 0]
  629. >>> w1 = z[label == 1]
  630. >>> w2 = z[label == 2]
  631. >>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0')
  632. >>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1')
  633. >>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2')
  634. >>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids')
  635. >>> plt.axis('equal')
  636. >>> plt.legend(shadow=True)
  637. >>> plt.show()
  638. """
  639. if int(iter) < 1:
  640. raise ValueError(f"Invalid iter ({iter}), must be a positive integer.")
  641. try:
  642. miss_meth = _valid_miss_meth[missing]
  643. except KeyError as e:
  644. raise ValueError(f"Unknown missing method {missing!r}") from e
  645. if isinstance(k, int):
  646. xp = array_namespace(data)
  647. else:
  648. xp = array_namespace(data, k)
  649. data = _asarray(data, xp=xp, check_finite=check_finite)
  650. code_book = xp_copy(k, xp=xp)
  651. if data.ndim == 1:
  652. d = 1
  653. elif data.ndim == 2:
  654. d = data.shape[1]
  655. else:
  656. raise ValueError("Input of rank > 2 is not supported.")
  657. if xp_size(data) < 1 or xp_size(code_book) < 1:
  658. raise ValueError("Empty input is not supported.")
  659. # If k is not a single value, it should be compatible with data's shape
  660. if minit == 'matrix' or xp_size(code_book) > 1:
  661. if data.ndim != code_book.ndim:
  662. raise ValueError("k array doesn't match data rank")
  663. nc = code_book.shape[0]
  664. if data.ndim > 1 and code_book.shape[1] != d:
  665. raise ValueError("k array doesn't match data dimension")
  666. else:
  667. nc = int(code_book)
  668. if nc < 1:
  669. raise ValueError(
  670. f"Cannot ask kmeans2 for {nc} clusters (k was {code_book})"
  671. )
  672. elif nc != code_book:
  673. warnings.warn("k was not an integer, was converted.", stacklevel=2)
  674. try:
  675. init_meth = _valid_init_meth[minit]
  676. except KeyError as e:
  677. raise ValueError(f"Unknown init method {minit!r}") from e
  678. else:
  679. rng = check_random_state(rng)
  680. code_book = init_meth(data, code_book, rng, xp)
  681. data = np.asarray(data)
  682. code_book = np.asarray(code_book)
  683. for _ in range(iter):
  684. # Compute the nearest neighbor for each obs using the current code book
  685. label = vq(data, code_book, check_finite=check_finite)[0]
  686. # Update the code book by computing centroids
  687. new_code_book, has_members = _vq.update_cluster_means(data, label, nc)
  688. if not has_members.all():
  689. miss_meth()
  690. # Set the empty clusters to their previous positions
  691. new_code_book[~has_members] = code_book[~has_members]
  692. code_book = new_code_book
  693. return xp.asarray(code_book), xp.asarray(label)