print_coercion_tables.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. #!/usr/bin/env python3
  2. """Prints type-coercion tables for the built-in NumPy types
  3. """
  4. from collections import namedtuple
  5. import numpy as np
  6. from numpy._core.numerictypes import obj2sctype
  7. # Generic object that can be added, but doesn't do anything else
  8. class GenericObject:
  9. def __init__(self, v):
  10. self.v = v
  11. def __add__(self, other):
  12. return self
  13. def __radd__(self, other):
  14. return self
  15. dtype = np.dtype('O')
  16. def print_cancast_table(ntypes):
  17. print('X', end=' ')
  18. for char in ntypes:
  19. print(char, end=' ')
  20. print()
  21. for row in ntypes:
  22. print(row, end=' ')
  23. for col in ntypes:
  24. if np.can_cast(row, col, "equiv"):
  25. cast = "#"
  26. elif np.can_cast(row, col, "safe"):
  27. cast = "="
  28. elif np.can_cast(row, col, "same_kind"):
  29. cast = "~"
  30. elif np.can_cast(row, col, "unsafe"):
  31. cast = "."
  32. else:
  33. cast = " "
  34. print(cast, end=' ')
  35. print()
  36. def print_coercion_table(ntypes, inputfirstvalue, inputsecondvalue, firstarray,
  37. use_promote_types=False):
  38. print('+', end=' ')
  39. for char in ntypes:
  40. print(char, end=' ')
  41. print()
  42. for row in ntypes:
  43. if row == 'O':
  44. rowtype = GenericObject
  45. else:
  46. rowtype = obj2sctype(row)
  47. print(row, end=' ')
  48. for col in ntypes:
  49. if col == 'O':
  50. coltype = GenericObject
  51. else:
  52. coltype = obj2sctype(col)
  53. try:
  54. if firstarray:
  55. rowvalue = np.array([rowtype(inputfirstvalue)], dtype=rowtype)
  56. else:
  57. rowvalue = rowtype(inputfirstvalue)
  58. colvalue = coltype(inputsecondvalue)
  59. if use_promote_types:
  60. char = np.promote_types(rowvalue.dtype, colvalue.dtype).char
  61. else:
  62. value = np.add(rowvalue, colvalue)
  63. if isinstance(value, np.ndarray):
  64. char = value.dtype.char
  65. else:
  66. char = np.dtype(type(value)).char
  67. except ValueError:
  68. char = '!'
  69. except OverflowError:
  70. char = '@'
  71. except TypeError:
  72. char = '#'
  73. print(char, end=' ')
  74. print()
  75. def print_new_cast_table(*, can_cast=True, legacy=False, flags=False):
  76. """Prints new casts, the values given are default "can-cast" values, not
  77. actual ones.
  78. """
  79. from numpy._core._multiarray_tests import get_all_cast_information
  80. cast_table = {
  81. -1: " ",
  82. 0: "#", # No cast (classify as equivalent here)
  83. 1: "#", # equivalent casting
  84. 2: "=", # safe casting
  85. 3: "~", # same-kind casting
  86. 4: ".", # unsafe casting
  87. }
  88. flags_table = {
  89. 0: "▗", 7: "█",
  90. 1: "▚", 2: "▐", 4: "▄",
  91. 3: "▜", 5: "▙",
  92. 6: "▟",
  93. }
  94. cast_info = namedtuple("cast_info", ["can_cast", "legacy", "flags"])
  95. no_cast_info = cast_info(" ", " ", " ")
  96. casts = get_all_cast_information()
  97. table = {}
  98. dtypes = set()
  99. for cast in casts:
  100. dtypes.add(cast["from"])
  101. dtypes.add(cast["to"])
  102. if cast["from"] not in table:
  103. table[cast["from"]] = {}
  104. to_dict = table[cast["from"]]
  105. can_cast = cast_table[cast["casting"]]
  106. legacy = "L" if cast["legacy"] else "."
  107. flags = 0
  108. if cast["requires_pyapi"]:
  109. flags |= 1
  110. if cast["supports_unaligned"]:
  111. flags |= 2
  112. if cast["no_floatingpoint_errors"]:
  113. flags |= 4
  114. flags = flags_table[flags]
  115. to_dict[cast["to"]] = cast_info(can_cast=can_cast, legacy=legacy, flags=flags)
  116. # The np.dtype(x.type) is a bit strange, because dtype classes do
  117. # not expose much yet.
  118. types = np.typecodes["All"]
  119. def sorter(x):
  120. # This is a bit weird hack, to get a table as close as possible to
  121. # the one printing all typecodes (but expecting user-dtypes).
  122. dtype = np.dtype(x.type)
  123. try:
  124. indx = types.index(dtype.char)
  125. except ValueError:
  126. indx = np.inf
  127. return (indx, dtype.char)
  128. dtypes = sorted(dtypes, key=sorter)
  129. def print_table(field="can_cast"):
  130. print('X', end=' ')
  131. for dt in dtypes:
  132. print(np.dtype(dt.type).char, end=' ')
  133. print()
  134. for from_dt in dtypes:
  135. print(np.dtype(from_dt.type).char, end=' ')
  136. row = table.get(from_dt, {})
  137. for to_dt in dtypes:
  138. print(getattr(row.get(to_dt, no_cast_info), field), end=' ')
  139. print()
  140. if can_cast:
  141. # Print the actual table:
  142. print()
  143. print("Casting: # is equivalent, = is safe, ~ is same-kind, and . is unsafe")
  144. print()
  145. print_table("can_cast")
  146. if legacy:
  147. print()
  148. print("L denotes a legacy cast . a non-legacy one.")
  149. print()
  150. print_table("legacy")
  151. if flags:
  152. print()
  153. print(f"{flags_table[0]}: no flags, "
  154. f"{flags_table[1]}: PyAPI, "
  155. f"{flags_table[2]}: supports unaligned, "
  156. f"{flags_table[4]}: no-float-errors")
  157. print()
  158. print_table("flags")
  159. if __name__ == '__main__':
  160. print("can cast")
  161. print_cancast_table(np.typecodes['All'])
  162. print()
  163. print("In these tables, ValueError is '!', OverflowError is '@', TypeError is '#'")
  164. print()
  165. print("scalar + scalar")
  166. print_coercion_table(np.typecodes['All'], 0, 0, False)
  167. print()
  168. print("scalar + neg scalar")
  169. print_coercion_table(np.typecodes['All'], 0, -1, False)
  170. print()
  171. print("array + scalar")
  172. print_coercion_table(np.typecodes['All'], 0, 0, True)
  173. print()
  174. print("array + neg scalar")
  175. print_coercion_table(np.typecodes['All'], 0, -1, True)
  176. print()
  177. print("promote_types")
  178. print_coercion_table(np.typecodes['All'], 0, 0, False, True)
  179. print("New casting type promotion:")
  180. print_new_cast_table(can_cast=True, legacy=True, flags=True)