_arraytools.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. """
  2. Functions for acting on an axis of an array.
  3. """
  4. import numpy as np
  5. def axis_slice(a, start=None, stop=None, step=None, axis=-1):
  6. """Take a slice along axis 'axis' from 'a'.
  7. Parameters
  8. ----------
  9. a : numpy.ndarray
  10. The array to be sliced.
  11. start, stop, step : int or None
  12. The slice parameters.
  13. axis : int, optional
  14. The axis of `a` to be sliced.
  15. Examples
  16. --------
  17. >>> import numpy as np
  18. >>> from scipy.signal._arraytools import axis_slice
  19. >>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  20. >>> axis_slice(a, start=0, stop=1, axis=1)
  21. array([[1],
  22. [4],
  23. [7]])
  24. >>> axis_slice(a, start=1, axis=0)
  25. array([[4, 5, 6],
  26. [7, 8, 9]])
  27. Notes
  28. -----
  29. The keyword arguments start, stop and step are used by calling
  30. slice(start, stop, step). This implies axis_slice() does not
  31. handle its arguments the exactly the same as indexing. To select
  32. a single index k, for example, use
  33. axis_slice(a, start=k, stop=k+1)
  34. In this case, the length of the axis 'axis' in the result will
  35. be 1; the trivial dimension is not removed. (Use numpy.squeeze()
  36. to remove trivial axes.)
  37. """
  38. a_slice = [slice(None)] * a.ndim
  39. a_slice[axis] = slice(start, stop, step)
  40. b = a[tuple(a_slice)]
  41. return b
  42. def axis_reverse(a, axis=-1):
  43. """Reverse the 1-D slices of `a` along axis `axis`.
  44. Returns axis_slice(a, step=-1, axis=axis).
  45. """
  46. return axis_slice(a, step=-1, axis=axis)
  47. def odd_ext(x, n, axis=-1):
  48. """
  49. Odd extension at the boundaries of an array
  50. Generate a new ndarray by making an odd extension of `x` along an axis.
  51. Parameters
  52. ----------
  53. x : ndarray
  54. The array to be extended.
  55. n : int
  56. The number of elements by which to extend `x` at each end of the axis.
  57. axis : int, optional
  58. The axis along which to extend `x`. Default is -1.
  59. Examples
  60. --------
  61. >>> import numpy as np
  62. >>> from scipy.signal._arraytools import odd_ext
  63. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  64. >>> odd_ext(a, 2)
  65. array([[-1, 0, 1, 2, 3, 4, 5, 6, 7],
  66. [-4, -1, 0, 1, 4, 9, 16, 23, 28]])
  67. Odd extension is a "180 degree rotation" at the endpoints of the original
  68. array:
  69. >>> t = np.linspace(0, 1.5, 100)
  70. >>> a = 0.9 * np.sin(2 * np.pi * t**2)
  71. >>> b = odd_ext(a, 40)
  72. >>> import matplotlib.pyplot as plt
  73. >>> plt.plot(np.arange(-40, 140), b, 'b', lw=1, label='odd extension')
  74. >>> plt.plot(np.arange(100), a, 'r', lw=2, label='original')
  75. >>> plt.legend(loc='best')
  76. >>> plt.show()
  77. """
  78. if n < 1:
  79. return x
  80. if n > x.shape[axis] - 1:
  81. raise ValueError(("The extension length n (%d) is too big. " +
  82. "It must not exceed x.shape[axis]-1, which is %d.")
  83. % (n, x.shape[axis] - 1))
  84. left_end = axis_slice(x, start=0, stop=1, axis=axis)
  85. left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
  86. right_end = axis_slice(x, start=-1, axis=axis)
  87. right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
  88. ext = np.concatenate((2 * left_end - left_ext,
  89. x,
  90. 2 * right_end - right_ext),
  91. axis=axis)
  92. return ext
  93. def even_ext(x, n, axis=-1):
  94. """
  95. Even extension at the boundaries of an array
  96. Generate a new ndarray by making an even extension of `x` along an axis.
  97. Parameters
  98. ----------
  99. x : ndarray
  100. The array to be extended.
  101. n : int
  102. The number of elements by which to extend `x` at each end of the axis.
  103. axis : int, optional
  104. The axis along which to extend `x`. Default is -1.
  105. Examples
  106. --------
  107. >>> import numpy as np
  108. >>> from scipy.signal._arraytools import even_ext
  109. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  110. >>> even_ext(a, 2)
  111. array([[ 3, 2, 1, 2, 3, 4, 5, 4, 3],
  112. [ 4, 1, 0, 1, 4, 9, 16, 9, 4]])
  113. Even extension is a "mirror image" at the boundaries of the original array:
  114. >>> t = np.linspace(0, 1.5, 100)
  115. >>> a = 0.9 * np.sin(2 * np.pi * t**2)
  116. >>> b = even_ext(a, 40)
  117. >>> import matplotlib.pyplot as plt
  118. >>> plt.plot(np.arange(-40, 140), b, 'b', lw=1, label='even extension')
  119. >>> plt.plot(np.arange(100), a, 'r', lw=2, label='original')
  120. >>> plt.legend(loc='best')
  121. >>> plt.show()
  122. """
  123. if n < 1:
  124. return x
  125. if n > x.shape[axis] - 1:
  126. raise ValueError(("The extension length n (%d) is too big. " +
  127. "It must not exceed x.shape[axis]-1, which is %d.")
  128. % (n, x.shape[axis] - 1))
  129. left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
  130. right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
  131. ext = np.concatenate((left_ext,
  132. x,
  133. right_ext),
  134. axis=axis)
  135. return ext
  136. def const_ext(x, n, axis=-1):
  137. """
  138. Constant extension at the boundaries of an array
  139. Generate a new ndarray that is a constant extension of `x` along an axis.
  140. The extension repeats the values at the first and last element of
  141. the axis.
  142. Parameters
  143. ----------
  144. x : ndarray
  145. The array to be extended.
  146. n : int
  147. The number of elements by which to extend `x` at each end of the axis.
  148. axis : int, optional
  149. The axis along which to extend `x`. Default is -1.
  150. Examples
  151. --------
  152. >>> import numpy as np
  153. >>> from scipy.signal._arraytools import const_ext
  154. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  155. >>> const_ext(a, 2)
  156. array([[ 1, 1, 1, 2, 3, 4, 5, 5, 5],
  157. [ 0, 0, 0, 1, 4, 9, 16, 16, 16]])
  158. Constant extension continues with the same values as the endpoints of the
  159. array:
  160. >>> t = np.linspace(0, 1.5, 100)
  161. >>> a = 0.9 * np.sin(2 * np.pi * t**2)
  162. >>> b = const_ext(a, 40)
  163. >>> import matplotlib.pyplot as plt
  164. >>> plt.plot(np.arange(-40, 140), b, 'b', lw=1, label='constant extension')
  165. >>> plt.plot(np.arange(100), a, 'r', lw=2, label='original')
  166. >>> plt.legend(loc='best')
  167. >>> plt.show()
  168. """
  169. if n < 1:
  170. return x
  171. left_end = axis_slice(x, start=0, stop=1, axis=axis)
  172. ones_shape = [1] * x.ndim
  173. ones_shape[axis] = n
  174. ones = np.ones(ones_shape, dtype=x.dtype)
  175. left_ext = ones * left_end
  176. right_end = axis_slice(x, start=-1, axis=axis)
  177. right_ext = ones * right_end
  178. ext = np.concatenate((left_ext,
  179. x,
  180. right_ext),
  181. axis=axis)
  182. return ext
  183. def zero_ext(x, n, axis=-1):
  184. """
  185. Zero padding at the boundaries of an array
  186. Generate a new ndarray that is a zero-padded extension of `x` along
  187. an axis.
  188. Parameters
  189. ----------
  190. x : ndarray
  191. The array to be extended.
  192. n : int
  193. The number of elements by which to extend `x` at each end of the
  194. axis.
  195. axis : int, optional
  196. The axis along which to extend `x`. Default is -1.
  197. Examples
  198. --------
  199. >>> import numpy as np
  200. >>> from scipy.signal._arraytools import zero_ext
  201. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  202. >>> zero_ext(a, 2)
  203. array([[ 0, 0, 1, 2, 3, 4, 5, 0, 0],
  204. [ 0, 0, 0, 1, 4, 9, 16, 0, 0]])
  205. """
  206. if n < 1:
  207. return x
  208. zeros_shape = list(x.shape)
  209. zeros_shape[axis] = n
  210. zeros = np.zeros(zeros_shape, dtype=x.dtype)
  211. ext = np.concatenate((zeros, x, zeros), axis=axis)
  212. return ext
  213. def _validate_fs(fs, allow_none=True):
  214. """
  215. Check if the given sampling frequency is a scalar and raises an exception
  216. otherwise. If allow_none is False, also raises an exception for none
  217. sampling rates. Returns the sampling frequency as float or none if the
  218. input is none.
  219. """
  220. if fs is None:
  221. if not allow_none:
  222. raise ValueError("Sampling frequency can not be none.")
  223. else: # should be float
  224. if not np.isscalar(fs):
  225. raise ValueError("Sampling frequency fs must be a single scalar.")
  226. fs = float(fs)
  227. return fs