_utils.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Adapted from https://stackoverflow.com/a/9558001/2536294
  2. import ast
  3. import functools
  4. import operator as op
  5. from dataclasses import dataclass
  6. from ._multiprocessing_helpers import mp
  7. if mp is not None:
  8. from .externals.loky.process_executor import _ExceptionWithTraceback
  9. # supported operators
  10. operators = {
  11. ast.Add: op.add,
  12. ast.Sub: op.sub,
  13. ast.Mult: op.mul,
  14. ast.Div: op.truediv,
  15. ast.FloorDiv: op.floordiv,
  16. ast.Mod: op.mod,
  17. ast.Pow: op.pow,
  18. ast.USub: op.neg,
  19. }
  20. def eval_expr(expr):
  21. """Somewhat safely evaluate an arithmetic expression.
  22. >>> eval_expr('2*6')
  23. 12
  24. >>> eval_expr('2**6')
  25. 64
  26. >>> eval_expr('1 + 2*3**(4) / (6 + -7)')
  27. -161.0
  28. Raises ValueError if the expression is invalid, too long
  29. or its computation involves too large values.
  30. """
  31. # Restrict the length of the expression to avoid potential Python crashes
  32. # as per the documentation of ast.parse.
  33. max_length = 30
  34. if len(expr) > max_length:
  35. raise ValueError(
  36. f"Expression {expr[:max_length]!r}... is too long. "
  37. f"Max length is {max_length}, got {len(expr)}."
  38. )
  39. try:
  40. return eval_(ast.parse(expr, mode="eval").body)
  41. except (TypeError, SyntaxError, OverflowError, KeyError) as e:
  42. raise ValueError(
  43. f"{expr!r} is not a valid or supported arithmetic expression."
  44. ) from e
  45. def limit(max_=None):
  46. """Return decorator that limits allowed returned values."""
  47. def decorator(func):
  48. @functools.wraps(func)
  49. def wrapper(*args, **kwargs):
  50. ret = func(*args, **kwargs)
  51. try:
  52. mag = abs(ret)
  53. except TypeError:
  54. pass # not applicable
  55. else:
  56. if mag > max_:
  57. raise ValueError(
  58. f"Numeric literal {ret} is too large, max is {max_}."
  59. )
  60. return ret
  61. return wrapper
  62. return decorator
  63. @limit(max_=10**6)
  64. def eval_(node):
  65. if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
  66. return node.value
  67. elif isinstance(node, ast.BinOp): # <left> <operator> <right>
  68. return operators[type(node.op)](eval_(node.left), eval_(node.right))
  69. elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1
  70. return operators[type(node.op)](eval_(node.operand))
  71. else:
  72. raise TypeError(node)
  73. @dataclass(frozen=True)
  74. class _Sentinel:
  75. """A sentinel to mark a parameter as not explicitly set"""
  76. default_value: object
  77. def __repr__(self):
  78. return f"default({self.default_value!r})"
  79. class _TracebackCapturingWrapper:
  80. """Protect function call and return error with traceback."""
  81. def __init__(self, func):
  82. self.func = func
  83. def __call__(self, **kwargs):
  84. try:
  85. return self.func(**kwargs)
  86. except BaseException as e:
  87. return _ExceptionWithTraceback(e)
  88. def _retrieve_traceback_capturing_wrapped_call(out):
  89. if isinstance(out, _ExceptionWithTraceback):
  90. rebuild, args = out.__reduce__()
  91. out = rebuild(*args)
  92. if isinstance(out, BaseException):
  93. raise out
  94. return out