basisdependent.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING
  3. from sympy.simplify import simplify as simp, trigsimp as tsimp # type: ignore
  4. from sympy.core.decorators import call_highest_priority, _sympifyit
  5. from sympy.core.assumptions import StdFactKB
  6. from sympy.core.function import diff as df
  7. from sympy.integrals.integrals import Integral
  8. from sympy.polys.polytools import factor as fctr
  9. from sympy.core import S, Add, Mul
  10. from sympy.core.expr import Expr
  11. if TYPE_CHECKING:
  12. from sympy.vector.vector import BaseVector
  13. class BasisDependent(Expr):
  14. """
  15. Super class containing functionality common to vectors and
  16. dyadics.
  17. Named so because the representation of these quantities in
  18. sympy.vector is dependent on the basis they are expressed in.
  19. """
  20. zero: BasisDependentZero
  21. @call_highest_priority('__radd__')
  22. def __add__(self, other):
  23. return self._add_func(self, other)
  24. @call_highest_priority('__add__')
  25. def __radd__(self, other):
  26. return self._add_func(other, self)
  27. @call_highest_priority('__rsub__')
  28. def __sub__(self, other):
  29. return self._add_func(self, -other)
  30. @call_highest_priority('__sub__')
  31. def __rsub__(self, other):
  32. return self._add_func(other, -self)
  33. @_sympifyit('other', NotImplemented)
  34. @call_highest_priority('__rmul__')
  35. def __mul__(self, other):
  36. return self._mul_func(self, other)
  37. @_sympifyit('other', NotImplemented)
  38. @call_highest_priority('__mul__')
  39. def __rmul__(self, other):
  40. return self._mul_func(other, self)
  41. def __neg__(self):
  42. return self._mul_func(S.NegativeOne, self)
  43. @_sympifyit('other', NotImplemented)
  44. @call_highest_priority('__rtruediv__')
  45. def __truediv__(self, other):
  46. return self._div_helper(other)
  47. @call_highest_priority('__truediv__')
  48. def __rtruediv__(self, other):
  49. return TypeError("Invalid divisor for division")
  50. def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False):
  51. """
  52. Implements the SymPy evalf routine for this quantity.
  53. evalf's documentation
  54. =====================
  55. """
  56. options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict,
  57. 'quad':quad, 'verbose':verbose}
  58. vec = self.zero
  59. for k, v in self.components.items():
  60. vec += v.evalf(n, **options) * k
  61. return vec
  62. evalf.__doc__ += Expr.evalf.__doc__ # type: ignore
  63. n = evalf # type: ignore
  64. def simplify(self, **kwargs):
  65. """
  66. Implements the SymPy simplify routine for this quantity.
  67. simplify's documentation
  68. ========================
  69. """
  70. simp_components = [simp(v, **kwargs) * k for
  71. k, v in self.components.items()]
  72. return self._add_func(*simp_components)
  73. simplify.__doc__ += simp.__doc__ # type: ignore
  74. def trigsimp(self, **opts):
  75. """
  76. Implements the SymPy trigsimp routine, for this quantity.
  77. trigsimp's documentation
  78. ========================
  79. """
  80. trig_components = [tsimp(v, **opts) * k for
  81. k, v in self.components.items()]
  82. return self._add_func(*trig_components)
  83. trigsimp.__doc__ += tsimp.__doc__ # type: ignore
  84. def _eval_simplify(self, **kwargs):
  85. return self.simplify(**kwargs)
  86. def _eval_trigsimp(self, **opts):
  87. return self.trigsimp(**opts)
  88. def _eval_derivative(self, wrt):
  89. return self.diff(wrt)
  90. def _eval_Integral(self, *symbols, **assumptions):
  91. integral_components = [Integral(v, *symbols, **assumptions) * k
  92. for k, v in self.components.items()]
  93. return self._add_func(*integral_components)
  94. def as_numer_denom(self):
  95. """
  96. Returns the expression as a tuple wrt the following
  97. transformation -
  98. expression -> a/b -> a, b
  99. """
  100. return self, S.One
  101. def factor(self, *args, **kwargs):
  102. """
  103. Implements the SymPy factor routine, on the scalar parts
  104. of a basis-dependent expression.
  105. factor's documentation
  106. ========================
  107. """
  108. fctr_components = [fctr(v, *args, **kwargs) * k for
  109. k, v in self.components.items()]
  110. return self._add_func(*fctr_components)
  111. factor.__doc__ += fctr.__doc__ # type: ignore
  112. def as_coeff_Mul(self, rational=False):
  113. """Efficiently extract the coefficient of a product."""
  114. return (S.One, self)
  115. def as_coeff_add(self, *deps):
  116. """Efficiently extract the coefficient of a summation."""
  117. return 0, tuple(x * self.components[x] for x in self.components)
  118. def diff(self, *args, **kwargs):
  119. """
  120. Implements the SymPy diff routine, for vectors.
  121. diff's documentation
  122. ========================
  123. """
  124. for x in args:
  125. if isinstance(x, BasisDependent):
  126. raise TypeError("Invalid arg for differentiation")
  127. diff_components = [df(v, *args, **kwargs) * k for
  128. k, v in self.components.items()]
  129. return self._add_func(*diff_components)
  130. diff.__doc__ += df.__doc__ # type: ignore
  131. def doit(self, **hints):
  132. """Calls .doit() on each term in the Dyadic"""
  133. doit_components = [self.components[x].doit(**hints) * x
  134. for x in self.components]
  135. return self._add_func(*doit_components)
  136. class BasisDependentAdd(BasisDependent, Add):
  137. """
  138. Denotes sum of basis dependent quantities such that they cannot
  139. be expressed as base or Mul instances.
  140. """
  141. def __new__(cls, *args, **options):
  142. components = {}
  143. # Check each arg and simultaneously learn the components
  144. for arg in args:
  145. if not isinstance(arg, cls._expr_type):
  146. if isinstance(arg, Mul):
  147. arg = cls._mul_func(*(arg.args))
  148. elif isinstance(arg, Add):
  149. arg = cls._add_func(*(arg.args))
  150. else:
  151. raise TypeError(str(arg) +
  152. " cannot be interpreted correctly")
  153. # If argument is zero, ignore
  154. if arg == cls.zero:
  155. continue
  156. # Else, update components accordingly
  157. for x in arg.components:
  158. components[x] = components.get(x, 0) + arg.components[x]
  159. temp = list(components.keys())
  160. for x in temp:
  161. if components[x] == 0:
  162. del components[x]
  163. # Handle case of zero vector
  164. if len(components) == 0:
  165. return cls.zero
  166. # Build object
  167. newargs = [x * components[x] for x in components]
  168. obj = super().__new__(cls, *newargs, **options)
  169. if isinstance(obj, Mul):
  170. return cls._mul_func(*obj.args)
  171. assumptions = {'commutative': True}
  172. obj._assumptions = StdFactKB(assumptions)
  173. obj._components = components
  174. obj._sys = (list(components.keys()))[0]._sys
  175. return obj
  176. class BasisDependentMul(BasisDependent, Mul):
  177. """
  178. Denotes product of base- basis dependent quantity with a scalar.
  179. """
  180. def __new__(cls, *args, **options):
  181. obj = cls._new(*args, **options)
  182. return obj
  183. def _new_rawargs(self, *args):
  184. # XXX: This is needed because Add.flatten() uses it but the default
  185. # implementation does not work for Vectors because they assign
  186. # attributes outside of .args.
  187. return type(self)(*args)
  188. @classmethod
  189. def _new(cls, *args, **options):
  190. from sympy.vector import Cross, Dot, Curl, Gradient
  191. count = 0
  192. measure_number = S.One
  193. zeroflag = False
  194. extra_args = []
  195. # Determine the component and check arguments
  196. # Also keep a count to ensure two vectors aren't
  197. # being multiplied
  198. for arg in args:
  199. if isinstance(arg, cls._zero_func):
  200. count += 1
  201. zeroflag = True
  202. elif arg == S.Zero:
  203. zeroflag = True
  204. elif isinstance(arg, (cls._base_func, cls._mul_func)):
  205. count += 1
  206. expr = arg._base_instance
  207. measure_number *= arg._measure_number
  208. elif isinstance(arg, cls._add_func):
  209. count += 1
  210. expr = arg
  211. elif isinstance(arg, (Cross, Dot, Curl, Gradient)):
  212. extra_args.append(arg)
  213. else:
  214. measure_number *= arg
  215. # Make sure incompatible types weren't multiplied
  216. if count > 1:
  217. raise ValueError("Invalid multiplication")
  218. elif count == 0:
  219. return Mul(*args, **options)
  220. # Handle zero vector case
  221. if zeroflag:
  222. return cls.zero
  223. # If one of the args was a VectorAdd, return an
  224. # appropriate VectorAdd instance
  225. if isinstance(expr, cls._add_func):
  226. newargs = [cls._mul_func(measure_number, x) for
  227. x in expr.args]
  228. return cls._add_func(*newargs)
  229. obj = super().__new__(cls, measure_number,
  230. expr._base_instance,
  231. *extra_args,
  232. **options)
  233. if isinstance(obj, Add):
  234. return cls._add_func(*obj.args)
  235. obj._base_instance = expr._base_instance
  236. obj._measure_number = measure_number
  237. assumptions = {'commutative': True}
  238. obj._assumptions = StdFactKB(assumptions)
  239. obj._components = {expr._base_instance: measure_number}
  240. obj._sys = expr._base_instance._sys
  241. return obj
  242. def _sympystr(self, printer):
  243. measure_str = printer._print(self._measure_number)
  244. if ('(' in measure_str or '-' in measure_str or
  245. '+' in measure_str):
  246. measure_str = '(' + measure_str + ')'
  247. return measure_str + '*' + printer._print(self._base_instance)
  248. class BasisDependentZero(BasisDependent):
  249. """
  250. Class to denote a zero basis dependent instance.
  251. """
  252. components: dict['BaseVector', Expr] = {}
  253. _latex_form: str
  254. def __new__(cls):
  255. obj = super().__new__(cls)
  256. # Pre-compute a specific hash value for the zero vector
  257. # Use the same one always
  258. obj._hash = (S.Zero, cls).__hash__()
  259. return obj
  260. def __hash__(self):
  261. return self._hash
  262. @call_highest_priority('__req__')
  263. def __eq__(self, other):
  264. return isinstance(other, self._zero_func)
  265. __req__ = __eq__
  266. @call_highest_priority('__radd__')
  267. def __add__(self, other):
  268. if isinstance(other, self._expr_type):
  269. return other
  270. else:
  271. raise TypeError("Invalid argument types for addition")
  272. @call_highest_priority('__add__')
  273. def __radd__(self, other):
  274. if isinstance(other, self._expr_type):
  275. return other
  276. else:
  277. raise TypeError("Invalid argument types for addition")
  278. @call_highest_priority('__rsub__')
  279. def __sub__(self, other):
  280. if isinstance(other, self._expr_type):
  281. return -other
  282. else:
  283. raise TypeError("Invalid argument types for subtraction")
  284. @call_highest_priority('__sub__')
  285. def __rsub__(self, other):
  286. if isinstance(other, self._expr_type):
  287. return other
  288. else:
  289. raise TypeError("Invalid argument types for subtraction")
  290. def __neg__(self):
  291. return self
  292. def normalize(self):
  293. """
  294. Returns the normalized version of this vector.
  295. """
  296. return self
  297. def _sympystr(self, printer):
  298. return '0'