_ni_support.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright (C) 2003-2005 Peter J. Verveer
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions
  5. # are met:
  6. #
  7. # 1. Redistributions of source code must retain the above copyright
  8. # notice, this list of conditions and the following disclaimer.
  9. #
  10. # 2. Redistributions in binary form must reproduce the above
  11. # copyright notice, this list of conditions and the following
  12. # disclaimer in the documentation and/or other materials provided
  13. # with the distribution.
  14. #
  15. # 3. The name of the author may not be used to endorse or promote
  16. # products derived from this software without specific prior
  17. # written permission.
  18. #
  19. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
  20. # OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  21. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  22. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
  23. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  24. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
  25. # GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  26. # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
  27. # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  28. # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  29. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  30. from collections.abc import Iterable
  31. import operator
  32. import warnings
  33. import numpy as np
  34. def _extend_mode_to_code(mode, is_filter=False):
  35. """Convert an extension mode to the corresponding integer code.
  36. """
  37. if mode == 'nearest':
  38. return 0
  39. elif mode == 'wrap':
  40. return 1
  41. elif mode in ['reflect', 'grid-mirror']:
  42. return 2
  43. elif mode == 'mirror':
  44. return 3
  45. elif mode == 'constant':
  46. return 4
  47. elif mode == 'grid-wrap' and is_filter:
  48. return 1
  49. elif mode == 'grid-wrap':
  50. return 5
  51. elif mode == 'grid-constant' and is_filter:
  52. return 4
  53. elif mode == 'grid-constant':
  54. return 6
  55. else:
  56. raise RuntimeError('boundary mode not supported')
  57. def _normalize_sequence(input, rank):
  58. """If input is a scalar, create a sequence of length equal to the
  59. rank by duplicating the input. If input is a sequence,
  60. check if its length is equal to the length of array.
  61. """
  62. is_str = isinstance(input, str)
  63. if not is_str and np.iterable(input):
  64. normalized = list(input)
  65. if len(normalized) != rank:
  66. err = "sequence argument must have length equal to input rank"
  67. raise RuntimeError(err)
  68. else:
  69. normalized = [input] * rank
  70. return normalized
  71. def _get_output(output, input, shape=None, complex_output=False):
  72. if shape is None:
  73. shape = input.shape
  74. if output is None:
  75. if not complex_output:
  76. output = np.zeros(shape, dtype=input.dtype.name)
  77. else:
  78. complex_type = np.promote_types(input.dtype, np.complex64)
  79. output = np.zeros(shape, dtype=complex_type)
  80. elif isinstance(output, type | np.dtype):
  81. # Classes (like `np.float32`) and dtypes are interpreted as dtype
  82. if complex_output and np.dtype(output).kind != 'c':
  83. warnings.warn("promoting specified output dtype to complex", stacklevel=3)
  84. output = np.promote_types(output, np.complex64)
  85. output = np.zeros(shape, dtype=output)
  86. elif isinstance(output, str):
  87. output = np.dtype(output)
  88. if complex_output and output.kind != 'c':
  89. raise RuntimeError("output must have complex dtype")
  90. elif not issubclass(output.type, np.number):
  91. raise RuntimeError("output must have numeric dtype")
  92. output = np.zeros(shape, dtype=output)
  93. else:
  94. # output was supplied as an array
  95. output = np.asarray(output)
  96. if output.shape != shape:
  97. raise RuntimeError("output shape not correct")
  98. elif complex_output and output.dtype.kind != 'c':
  99. raise RuntimeError("output must have complex dtype")
  100. return output
  101. def _check_axes(axes, ndim):
  102. if axes is None:
  103. return tuple(range(ndim))
  104. elif np.isscalar(axes):
  105. axes = (operator.index(axes),)
  106. elif isinstance(axes, Iterable):
  107. for ax in axes:
  108. axes = tuple(operator.index(ax) for ax in axes)
  109. if ax < -ndim or ax > ndim - 1:
  110. raise ValueError(f"specified axis: {ax} is out of range")
  111. axes = tuple(ax % ndim if ax < 0 else ax for ax in axes)
  112. else:
  113. message = "axes must be an integer, iterable of integers, or None"
  114. raise ValueError(message)
  115. if len(tuple(set(axes))) != len(axes):
  116. raise ValueError("axes must be unique")
  117. return axes
  118. def _skip_if_dtype(arg):
  119. """'array or dtype' polymorphism.
  120. Return None for np.int8, dtype('float32') or 'f' etc
  121. arg for np.empty(3) etc
  122. """
  123. if isinstance(arg, str):
  124. return None
  125. if type(arg) is type:
  126. return None if issubclass(arg, np.generic) else arg
  127. else:
  128. return None if isinstance(arg, np.dtype) else arg