| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- # Copyright (C) 2003-2005 Peter J. Verveer
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions
- # are met:
- #
- # 1. Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- #
- # 2. Redistributions in binary form must reproduce the above
- # copyright notice, this list of conditions and the following
- # disclaimer in the documentation and/or other materials provided
- # with the distribution.
- #
- # 3. The name of the author may not be used to endorse or promote
- # products derived from this software without specific prior
- # written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
- # OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
- # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
- # GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
- # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
- # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
- # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- from collections.abc import Iterable
- import operator
- import warnings
- import numpy as np
- def _extend_mode_to_code(mode, is_filter=False):
- """Convert an extension mode to the corresponding integer code.
- """
- if mode == 'nearest':
- return 0
- elif mode == 'wrap':
- return 1
- elif mode in ['reflect', 'grid-mirror']:
- return 2
- elif mode == 'mirror':
- return 3
- elif mode == 'constant':
- return 4
- elif mode == 'grid-wrap' and is_filter:
- return 1
- elif mode == 'grid-wrap':
- return 5
- elif mode == 'grid-constant' and is_filter:
- return 4
- elif mode == 'grid-constant':
- return 6
- else:
- raise RuntimeError('boundary mode not supported')
- def _normalize_sequence(input, rank):
- """If input is a scalar, create a sequence of length equal to the
- rank by duplicating the input. If input is a sequence,
- check if its length is equal to the length of array.
- """
- is_str = isinstance(input, str)
- if not is_str and np.iterable(input):
- normalized = list(input)
- if len(normalized) != rank:
- err = "sequence argument must have length equal to input rank"
- raise RuntimeError(err)
- else:
- normalized = [input] * rank
- return normalized
- def _get_output(output, input, shape=None, complex_output=False):
- if shape is None:
- shape = input.shape
- if output is None:
- if not complex_output:
- output = np.zeros(shape, dtype=input.dtype.name)
- else:
- complex_type = np.promote_types(input.dtype, np.complex64)
- output = np.zeros(shape, dtype=complex_type)
- elif isinstance(output, type | np.dtype):
- # Classes (like `np.float32`) and dtypes are interpreted as dtype
- if complex_output and np.dtype(output).kind != 'c':
- warnings.warn("promoting specified output dtype to complex", stacklevel=3)
- output = np.promote_types(output, np.complex64)
- output = np.zeros(shape, dtype=output)
- elif isinstance(output, str):
- output = np.dtype(output)
- if complex_output and output.kind != 'c':
- raise RuntimeError("output must have complex dtype")
- elif not issubclass(output.type, np.number):
- raise RuntimeError("output must have numeric dtype")
- output = np.zeros(shape, dtype=output)
- else:
- # output was supplied as an array
- output = np.asarray(output)
- if output.shape != shape:
- raise RuntimeError("output shape not correct")
- elif complex_output and output.dtype.kind != 'c':
- raise RuntimeError("output must have complex dtype")
- return output
- def _check_axes(axes, ndim):
- if axes is None:
- return tuple(range(ndim))
- elif np.isscalar(axes):
- axes = (operator.index(axes),)
- elif isinstance(axes, Iterable):
- for ax in axes:
- axes = tuple(operator.index(ax) for ax in axes)
- if ax < -ndim or ax > ndim - 1:
- raise ValueError(f"specified axis: {ax} is out of range")
- axes = tuple(ax % ndim if ax < 0 else ax for ax in axes)
- else:
- message = "axes must be an integer, iterable of integers, or None"
- raise ValueError(message)
- if len(tuple(set(axes))) != len(axes):
- raise ValueError("axes must be unique")
- return axes
- def _skip_if_dtype(arg):
- """'array or dtype' polymorphism.
- Return None for np.int8, dtype('float32') or 'f' etc
- arg for np.empty(3) etc
- """
- if isinstance(arg, str):
- return None
- if type(arg) is type:
- return None if issubclass(arg, np.generic) else arg
- else:
- return None if isinstance(arg, np.dtype) else arg
|