decorators.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. """Decorators for Shapely functions."""
  2. import os
  3. import warnings
  4. from collections.abc import Callable, Iterable
  5. from functools import lru_cache, wraps
  6. from inspect import unwrap
  7. import numpy as np
  8. from shapely import lib
  9. from shapely.errors import UnsupportedGEOSVersionError
  10. class requires_geos:
  11. """Decorator to require a minimum GEOS version."""
  12. def __init__(self, version):
  13. """Create a decorator that requires a minimum GEOS version."""
  14. if version.count(".") != 2:
  15. raise ValueError("Version must be <major>.<minor>.<patch> format")
  16. self.version = tuple(int(x) for x in version.split("."))
  17. def __call__(self, func):
  18. """Return the wrapped function."""
  19. is_compatible = lib.geos_version >= self.version
  20. is_doc_build = os.environ.get("SPHINX_DOC_BUILD") == "1" # set in docs/conf.py
  21. if is_compatible and not is_doc_build:
  22. return func # return directly, do not change the docstring
  23. msg = "'{}' requires at least GEOS {}.{}.{}.".format(
  24. func.__name__, *self.version
  25. )
  26. if is_compatible:
  27. @wraps(func)
  28. def wrapped(*args, **kwargs):
  29. return func(*args, **kwargs)
  30. else:
  31. @wraps(func)
  32. def wrapped(*args, **kwargs):
  33. raise UnsupportedGEOSVersionError(msg)
  34. doc = wrapped.__doc__
  35. if doc:
  36. # Insert the message at the first double newline
  37. position = doc.find("\n\n") + 2
  38. # Figure out the indentation level
  39. indent = 0
  40. while True:
  41. if doc[position + indent] == " ":
  42. indent += 1
  43. else:
  44. break
  45. wrapped.__doc__ = doc.replace(
  46. "\n\n", "\n\n{}.. note:: {}\n\n".format(" " * indent, msg), 1
  47. )
  48. return wrapped
  49. def multithreading_enabled(func):
  50. """Enable multithreading.
  51. To do this, the writable flags of object type ndarrays are set to False.
  52. NB: multithreading also requires the GIL to be released, which is done in
  53. the C extension (ufuncs.c).
  54. """
  55. @wraps(func)
  56. def wrapped(*args, **kwargs):
  57. array_args = [
  58. arg for arg in args if isinstance(arg, np.ndarray) and arg.dtype == object
  59. ] + [
  60. arg
  61. for name, arg in kwargs.items()
  62. if name not in {"where", "out"}
  63. and isinstance(arg, np.ndarray)
  64. and arg.dtype == object
  65. ]
  66. old_flags = [arr.flags.writeable for arr in array_args]
  67. try:
  68. for arr in array_args:
  69. arr.flags.writeable = False
  70. return func(*args, **kwargs)
  71. finally:
  72. for arr, old_flag in zip(array_args, old_flags):
  73. arr.flags.writeable = old_flag
  74. return wrapped
  75. def deprecate_positional(
  76. should_be_kwargs: Iterable[str],
  77. category: type[Warning] = DeprecationWarning,
  78. ):
  79. """Show warning if positional arguments are used that should be keyword.
  80. Parameters
  81. ----------
  82. should_be_kwargs : Iterable[str]
  83. Names of parameters that should be passed as keyword arguments.
  84. category : type[Warning], optional (default: DeprecationWarning)
  85. Warning category to use for deprecation warnings.
  86. Returns
  87. -------
  88. callable
  89. Decorator function that adds positional argument deprecation warnings.
  90. Examples
  91. --------
  92. >>> from shapely.decorators import deprecate_positional
  93. >>> @deprecate_positional(['b', 'c'])
  94. ... def example(a, b, c=None):
  95. ... return a, b, c
  96. ...
  97. >>> example(1, 2) # doctest: +SKIP
  98. DeprecationWarning: positional argument `b` for `example` is deprecated. ...
  99. (1, 2, None)
  100. >>> example(1, b=2) # No warnings
  101. (1, 2, None)
  102. """
  103. def decorator(func: Callable):
  104. code = unwrap(func).__code__
  105. # positional parameters are the first co_argcount names
  106. pos_names = code.co_varnames[: code.co_argcount]
  107. # build a name -> index map
  108. name_to_idx = {name: idx for idx, name in enumerate(pos_names)}
  109. # pick out only those names we care about
  110. deprecate_positions = [
  111. (name_to_idx[name], name)
  112. for name in should_be_kwargs
  113. if name in name_to_idx
  114. ]
  115. # early exit if there are no deprecated positional args
  116. if not deprecate_positions:
  117. return func
  118. # earliest position where a warning could occur
  119. warn_from = min(deprecate_positions)[0]
  120. @lru_cache(10)
  121. def make_msg(n_args: int):
  122. used = [name for idx, name in deprecate_positions if idx < n_args]
  123. if len(used) == 1:
  124. args_txt = f"`{used[0]}`"
  125. plr = ""
  126. isare = "is"
  127. else:
  128. plr = "s"
  129. isare = "are"
  130. if len(used) == 2:
  131. args_txt = " and ".join(f"`{u}`" for u in used)
  132. else:
  133. args_txt = ", ".join(f"`{u}`" for u in used[:-1])
  134. args_txt += f", and `{used[-1]}`"
  135. return (
  136. f"positional argument{plr} {args_txt} for `{func.__name__}` "
  137. f"{isare} deprecated. Please use keyword argument{plr} instead."
  138. )
  139. @wraps(func)
  140. def wrapper(*args, **kwargs):
  141. result = func(*args, **kwargs)
  142. n = len(args)
  143. if n > warn_from:
  144. warnings.warn(make_msg(n), category=category, stacklevel=2)
  145. return result
  146. return wrapper
  147. return decorator