| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- """Decorators for Shapely functions."""
- import os
- import warnings
- from collections.abc import Callable, Iterable
- from functools import lru_cache, wraps
- from inspect import unwrap
- import numpy as np
- from shapely import lib
- from shapely.errors import UnsupportedGEOSVersionError
- class requires_geos:
- """Decorator to require a minimum GEOS version."""
- def __init__(self, version):
- """Create a decorator that requires a minimum GEOS version."""
- if version.count(".") != 2:
- raise ValueError("Version must be <major>.<minor>.<patch> format")
- self.version = tuple(int(x) for x in version.split("."))
- def __call__(self, func):
- """Return the wrapped function."""
- is_compatible = lib.geos_version >= self.version
- is_doc_build = os.environ.get("SPHINX_DOC_BUILD") == "1" # set in docs/conf.py
- if is_compatible and not is_doc_build:
- return func # return directly, do not change the docstring
- msg = "'{}' requires at least GEOS {}.{}.{}.".format(
- func.__name__, *self.version
- )
- if is_compatible:
- @wraps(func)
- def wrapped(*args, **kwargs):
- return func(*args, **kwargs)
- else:
- @wraps(func)
- def wrapped(*args, **kwargs):
- raise UnsupportedGEOSVersionError(msg)
- doc = wrapped.__doc__
- if doc:
- # Insert the message at the first double newline
- position = doc.find("\n\n") + 2
- # Figure out the indentation level
- indent = 0
- while True:
- if doc[position + indent] == " ":
- indent += 1
- else:
- break
- wrapped.__doc__ = doc.replace(
- "\n\n", "\n\n{}.. note:: {}\n\n".format(" " * indent, msg), 1
- )
- return wrapped
- def multithreading_enabled(func):
- """Enable multithreading.
- To do this, the writable flags of object type ndarrays are set to False.
- NB: multithreading also requires the GIL to be released, which is done in
- the C extension (ufuncs.c).
- """
- @wraps(func)
- def wrapped(*args, **kwargs):
- array_args = [
- arg for arg in args if isinstance(arg, np.ndarray) and arg.dtype == object
- ] + [
- arg
- for name, arg in kwargs.items()
- if name not in {"where", "out"}
- and isinstance(arg, np.ndarray)
- and arg.dtype == object
- ]
- old_flags = [arr.flags.writeable for arr in array_args]
- try:
- for arr in array_args:
- arr.flags.writeable = False
- return func(*args, **kwargs)
- finally:
- for arr, old_flag in zip(array_args, old_flags):
- arr.flags.writeable = old_flag
- return wrapped
- def deprecate_positional(
- should_be_kwargs: Iterable[str],
- category: type[Warning] = DeprecationWarning,
- ):
- """Show warning if positional arguments are used that should be keyword.
- Parameters
- ----------
- should_be_kwargs : Iterable[str]
- Names of parameters that should be passed as keyword arguments.
- category : type[Warning], optional (default: DeprecationWarning)
- Warning category to use for deprecation warnings.
- Returns
- -------
- callable
- Decorator function that adds positional argument deprecation warnings.
- Examples
- --------
- >>> from shapely.decorators import deprecate_positional
- >>> @deprecate_positional(['b', 'c'])
- ... def example(a, b, c=None):
- ... return a, b, c
- ...
- >>> example(1, 2) # doctest: +SKIP
- DeprecationWarning: positional argument `b` for `example` is deprecated. ...
- (1, 2, None)
- >>> example(1, b=2) # No warnings
- (1, 2, None)
- """
- def decorator(func: Callable):
- code = unwrap(func).__code__
- # positional parameters are the first co_argcount names
- pos_names = code.co_varnames[: code.co_argcount]
- # build a name -> index map
- name_to_idx = {name: idx for idx, name in enumerate(pos_names)}
- # pick out only those names we care about
- deprecate_positions = [
- (name_to_idx[name], name)
- for name in should_be_kwargs
- if name in name_to_idx
- ]
- # early exit if there are no deprecated positional args
- if not deprecate_positions:
- return func
- # earliest position where a warning could occur
- warn_from = min(deprecate_positions)[0]
- @lru_cache(10)
- def make_msg(n_args: int):
- used = [name for idx, name in deprecate_positions if idx < n_args]
- if len(used) == 1:
- args_txt = f"`{used[0]}`"
- plr = ""
- isare = "is"
- else:
- plr = "s"
- isare = "are"
- if len(used) == 2:
- args_txt = " and ".join(f"`{u}`" for u in used)
- else:
- args_txt = ", ".join(f"`{u}`" for u in used[:-1])
- args_txt += f", and `{used[-1]}`"
- return (
- f"positional argument{plr} {args_txt} for `{func.__name__}` "
- f"{isare} deprecated. Please use keyword argument{plr} instead."
- )
- @wraps(func)
- def wrapper(*args, **kwargs):
- result = func(*args, **kwargs)
- n = len(args)
- if n > warn_from:
- warnings.warn(make_msg(n), category=category, stacklevel=2)
- return result
- return wrapper
- return decorator
|