textplot.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. from sympy.core.numbers import Float
  2. from sympy.core.symbol import Dummy
  3. from sympy.utilities.lambdify import lambdify
  4. import math
  5. def is_valid(x):
  6. """Check if a floating point number is valid"""
  7. if x is None:
  8. return False
  9. if isinstance(x, complex):
  10. return False
  11. return not math.isinf(x) and not math.isnan(x)
  12. def rescale(y, W, H, mi, ma):
  13. """Rescale the given array `y` to fit into the integer values
  14. between `0` and `H-1` for the values between ``mi`` and ``ma``.
  15. """
  16. y_new = []
  17. norm = ma - mi
  18. offset = (ma + mi) / 2
  19. for x in range(W):
  20. if is_valid(y[x]):
  21. normalized = (y[x] - offset) / norm
  22. if not is_valid(normalized):
  23. y_new.append(None)
  24. else:
  25. rescaled = Float((normalized*H + H/2) * (H-1)/H).round()
  26. rescaled = int(rescaled)
  27. y_new.append(rescaled)
  28. else:
  29. y_new.append(None)
  30. return y_new
  31. def linspace(start, stop, num):
  32. return [start + (stop - start) * x / (num-1) for x in range(num)]
  33. def textplot_str(expr, a, b, W=55, H=21):
  34. """Generator for the lines of the plot"""
  35. free = expr.free_symbols
  36. if len(free) > 1:
  37. raise ValueError(
  38. "The expression must have a single variable. (Got {})"
  39. .format(free))
  40. x = free.pop() if free else Dummy()
  41. f = lambdify([x], expr)
  42. if isinstance(a, complex):
  43. if a.imag == 0:
  44. a = a.real
  45. if isinstance(b, complex):
  46. if b.imag == 0:
  47. b = b.real
  48. a = float(a)
  49. b = float(b)
  50. # Calculate function values
  51. x = linspace(a, b, W)
  52. y = []
  53. for val in x:
  54. try:
  55. y.append(f(val))
  56. # Not sure what exceptions to catch here or why...
  57. except (ValueError, TypeError, ZeroDivisionError):
  58. y.append(None)
  59. # Normalize height to screen space
  60. y_valid = list(filter(is_valid, y))
  61. if y_valid:
  62. ma = max(y_valid)
  63. mi = min(y_valid)
  64. if ma == mi:
  65. if ma:
  66. mi, ma = sorted([0, 2*ma])
  67. else:
  68. mi, ma = -1, 1
  69. else:
  70. mi, ma = -1, 1
  71. y_range = ma - mi
  72. precision = math.floor(math.log10(y_range)) - 1
  73. precision *= -1
  74. mi = round(mi, precision)
  75. ma = round(ma, precision)
  76. y = rescale(y, W, H, mi, ma)
  77. y_bins = linspace(mi, ma, H)
  78. # Draw plot
  79. margin = 7
  80. for h in range(H - 1, -1, -1):
  81. s = [' '] * W
  82. for i in range(W):
  83. if y[i] == h:
  84. if (i == 0 or y[i - 1] == h - 1) and (i == W - 1 or y[i + 1] == h + 1):
  85. s[i] = '/'
  86. elif (i == 0 or y[i - 1] == h + 1) and (i == W - 1 or y[i + 1] == h - 1):
  87. s[i] = '\\'
  88. else:
  89. s[i] = '.'
  90. if h == 0:
  91. for i in range(W):
  92. s[i] = '_'
  93. # Print y values
  94. if h in (0, H//2, H - 1):
  95. prefix = ("%g" % y_bins[h]).rjust(margin)[:margin]
  96. else:
  97. prefix = " "*margin
  98. s = "".join(s)
  99. if h == H//2:
  100. s = s.replace(" ", "-")
  101. yield prefix + " |" + s
  102. # Print x values
  103. bottom = " " * (margin + 2)
  104. bottom += ("%g" % x[0]).ljust(W//2)
  105. if W % 2 == 1:
  106. bottom += ("%g" % x[W//2]).ljust(W//2)
  107. else:
  108. bottom += ("%g" % x[W//2]).ljust(W//2-1)
  109. bottom += "%g" % x[-1]
  110. yield bottom
  111. def textplot(expr, a, b, W=55, H=21):
  112. r"""
  113. Print a crude ASCII art plot of the SymPy expression 'expr' (which
  114. should contain a single symbol, e.g. x or something else) over the
  115. interval [a, b].
  116. Examples
  117. ========
  118. >>> from sympy import Symbol, sin
  119. >>> from sympy.plotting import textplot
  120. >>> t = Symbol('t')
  121. >>> textplot(sin(t)*t, 0, 15)
  122. 14 | ...
  123. | .
  124. | .
  125. | .
  126. | .
  127. | ...
  128. | / . .
  129. | /
  130. | / .
  131. | . . .
  132. 1.5 |----.......--------------------------------------------
  133. |.... \ . .
  134. | \ / .
  135. | .. / .
  136. | \ / .
  137. | ....
  138. | .
  139. | . .
  140. |
  141. | . .
  142. -11 |_______________________________________________________
  143. 0 7.5 15
  144. """
  145. for line in textplot_str(expr, a, b, W, H):
  146. print(line)