transforms.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. """Transforms that are always applied to quantum expressions.
  2. This module uses the kind and _constructor_postprocessor_mapping APIs
  3. to transform different combinations of Operators, Bras, and Kets into
  4. Inner/Outer/TensorProducts. These transformations are registered
  5. with the postprocessing API of core classes like `Mul` and `Pow` and
  6. are always applied to any expression involving Bras, Kets, and
  7. Operators. This API replaces the custom `__mul__` and `__pow__`
  8. methods of the quantum classes, which were found to be inconsistent.
  9. THIS IS EXPERIMENTAL.
  10. """
  11. from sympy.core.basic import Basic
  12. from sympy.core.expr import Expr
  13. from sympy.core.mul import Mul
  14. from sympy.core.singleton import S
  15. from sympy.multipledispatch.dispatcher import (
  16. Dispatcher, ambiguity_register_error_ignore_dup
  17. )
  18. from sympy.utilities.misc import debug
  19. from sympy.physics.quantum.innerproduct import InnerProduct
  20. from sympy.physics.quantum.kind import KetKind, BraKind, OperatorKind
  21. from sympy.physics.quantum.operator import (
  22. OuterProduct, IdentityOperator, Operator
  23. )
  24. from sympy.physics.quantum.state import BraBase, KetBase, StateBase
  25. from sympy.physics.quantum.tensorproduct import TensorProduct
  26. #-----------------------------------------------------------------------------
# Multipledispatch-based transforms for Mul and Pow
  28. #-----------------------------------------------------------------------------
  29. _transform_state_pair = Dispatcher('_transform_state_pair')
  30. """Transform a pair of expression in a Mul to their canonical form.
  31. All functions that are registered with this dispatcher need to take
  32. two inputs and return either tuple of transformed outputs, or None if no
  33. transform is applied. The output tuple is inserted into the right place
  34. of the ``Mul`` that is being put into canonical form. It works something like
  35. the following:
  36. ``Mul(a, b, c, d, e, f) -> Mul(*(_transform_state_pair(a, b) + (c, d, e, f))))``
  37. The transforms here are always applied when quantum objects are multiplied.
  38. THIS IS EXPERIMENTAL.
  39. However, users of ``sympy.physics.quantum`` can import this dispatcher and
  40. register their own transforms to control the canonical form of products
  41. of quantum expressions.
  42. """
  43. @_transform_state_pair.register(Expr, Expr)
  44. def _transform_expr(a, b):
  45. """Default transformer that does nothing for base types."""
  46. return None
  47. # The identity times anything is the anything.
  48. _transform_state_pair.add(
  49. (IdentityOperator, Expr),
  50. lambda x, y: (y,),
  51. on_ambiguity=ambiguity_register_error_ignore_dup
  52. )
  53. _transform_state_pair.add(
  54. (Expr, IdentityOperator),
  55. lambda x, y: (x,),
  56. on_ambiguity=ambiguity_register_error_ignore_dup
  57. )
  58. _transform_state_pair.add(
  59. (IdentityOperator, IdentityOperator),
  60. lambda x, y: S.One,
  61. on_ambiguity=ambiguity_register_error_ignore_dup
  62. )
  63. @_transform_state_pair.register(BraBase, KetBase)
  64. def _transform_bra_ket(a, b):
  65. """Transform a bra*ket -> InnerProduct(bra, ket)."""
  66. return (InnerProduct(a, b),)
  67. @_transform_state_pair.register(KetBase, BraBase)
  68. def _transform_ket_bra(a, b):
  69. """Transform a keT*bra -> OuterProduct(ket, bra)."""
  70. return (OuterProduct(a, b),)
  71. @_transform_state_pair.register(KetBase, KetBase)
  72. def _transform_ket_ket(a, b):
  73. """Raise a TypeError if a user tries to multiply two kets.
  74. Multiplication based on `*` is not a shorthand for tensor products.
  75. """
  76. raise TypeError(
  77. 'Multiplication of two kets is not allowed. Use TensorProduct instead.'
  78. )
  79. @_transform_state_pair.register(BraBase, BraBase)
  80. def _transform_bra_bra(a, b):
  81. """Raise a TypeError if a user tries to multiply two bras.
  82. Multiplication based on `*` is not a shorthand for tensor products.
  83. """
  84. raise TypeError(
  85. 'Multiplication of two bras is not allowed. Use TensorProduct instead.'
  86. )
  87. @_transform_state_pair.register(OuterProduct, KetBase)
  88. def _transform_op_ket(a, b):
  89. return (InnerProduct(a.bra, b), a.ket)
  90. @_transform_state_pair.register(BraBase, OuterProduct)
  91. def _transform_bra_op(a, b):
  92. return (InnerProduct(a, b.ket), b.bra)
  93. @_transform_state_pair.register(TensorProduct, KetBase)
  94. def _transform_tp_ket(a, b):
  95. """Raise a TypeError if a user tries to multiply TensorProduct(*kets)*ket.
  96. Multiplication based on `*` is not a shorthand for tensor products.
  97. """
  98. if a.kind == KetKind:
  99. raise TypeError(
  100. 'Multiplication of TensorProduct(*kets)*ket is invalid.'
  101. )
  102. @_transform_state_pair.register(KetBase, TensorProduct)
  103. def _transform_ket_tp(a, b):
  104. """Raise a TypeError if a user tries to multiply ket*TensorProduct(*kets).
  105. Multiplication based on `*` is not a shorthand for tensor products.
  106. """
  107. if b.kind == KetKind:
  108. raise TypeError(
  109. 'Multiplication of ket*TensorProduct(*kets) is invalid.'
  110. )
  111. @_transform_state_pair.register(TensorProduct, BraBase)
  112. def _transform_tp_bra(a, b):
  113. """Raise a TypeError if a user tries to multiply TensorProduct(*bras)*bra.
  114. Multiplication based on `*` is not a shorthand for tensor products.
  115. """
  116. if a.kind == BraKind:
  117. raise TypeError(
  118. 'Multiplication of TensorProduct(*bras)*bra is invalid.'
  119. )
  120. @_transform_state_pair.register(BraBase, TensorProduct)
  121. def _transform_bra_tp(a, b):
  122. """Raise a TypeError if a user tries to multiply bra*TensorProduct(*bras).
  123. Multiplication based on `*` is not a shorthand for tensor products.
  124. """
  125. if b.kind == BraKind:
  126. raise TypeError(
  127. 'Multiplication of bra*TensorProduct(*bras) is invalid.'
  128. )
  129. @_transform_state_pair.register(TensorProduct, TensorProduct)
  130. def _transform_tp_tp(a, b):
  131. """Combine a product of tensor products if their number of args matches."""
  132. debug('_transform_tp_tp', a, b)
  133. if len(a.args) == len(b.args):
  134. if a.kind == BraKind and b.kind == KetKind:
  135. return tuple([InnerProduct(i, j) for (i, j) in zip(a.args, b.args)])
  136. else:
  137. return (TensorProduct(*(i*j for (i, j) in zip(a.args, b.args))), )
  138. @_transform_state_pair.register(OuterProduct, OuterProduct)
  139. def _transform_op_op(a, b):
  140. """Extract an inner produt from a product of outer products."""
  141. return (InnerProduct(a.bra, b.ket), OuterProduct(a.ket, b.bra))
  142. #-----------------------------------------------------------------------------
  143. # Postprocessing transforms for Mul and Pow
  144. #-----------------------------------------------------------------------------
  145. def _postprocess_state_mul(expr):
  146. """Transform a ``Mul`` of quantum expressions into canonical form.
  147. This function is registered ``_constructor_postprocessor_mapping`` as a
  148. transformer for ``Mul``. This means that every time a quantum expression
  149. is multiplied, this function will be called to transform it into canonical
  150. form as defined by the binary functions registered with
  151. ``_transform_state_pair``.
  152. The algorithm of this function is as follows. It walks the args
  153. of the input ``Mul`` from left to right and calls ``_transform_state_pair``
  154. on every overlapping pair of args. Each time ``_transform_state_pair``
  155. is called it can return a tuple of items or None. If None, the pair isn't
  156. transformed. If a tuple, then the last element of the tuple goes back into
  157. the args to be transformed again and the others are extended onto the result
  158. args list.
  159. The algorithm can be visualized in the following table:
  160. step result args
  161. ============================================================================
  162. #0 [] [a, b, c, d, e, f]
  163. #1 [] [T(a,b), c, d, e, f]
  164. #2 [T(a,b)[:-1]] [T(a,b)[-1], c, d, e, f]
  165. #3 [T(a,b)[:-1]] [T(T(a,b)[-1], c), d, e, f]
  166. #4 [T(a,b)[:-1], T(T(a,b)[-1], c)[:-1]] [T(T(T(a,b)[-1], c)[-1], d), e, f]
  167. #5 ...
  168. One limitation of the current implementation is that we assume that only the
  169. last item of the transformed tuple goes back into the args to be transformed
  170. again. These seems to handle the cases needed for Mul. However, we may need
  171. to extend the algorithm to have the entire tuple go back into the args for
  172. further transformation.
  173. """
  174. args = list(expr.args)
  175. result = []
  176. # Continue as long as we have at least 2 elements
  177. while len(args) > 1:
  178. # Get first two elements
  179. first = args.pop(0)
  180. second = args[0] # Look at second element without popping yet
  181. transformed = _transform_state_pair(first, second)
  182. if transformed is None:
  183. # If transform returns None, append first element
  184. result.append(first)
  185. else:
  186. # This item was transformed, pop and discard
  187. args.pop(0)
  188. # The last item goes back to be transformed again
  189. args.insert(0, transformed[-1])
  190. # All other items go directly into the result
  191. result.extend(transformed[:-1])
  192. # Append any remaining element
  193. if args:
  194. result.append(args[0])
  195. return Mul._from_args(result, is_commutative=False)
  196. def _postprocess_state_pow(expr):
  197. """Handle bras and kets raised to powers.
  198. Under ``*`` multiplication this is invalid. Users should use a
  199. TensorProduct instead.
  200. """
  201. base, exp = expr.as_base_exp()
  202. if base.kind == KetKind or base.kind == BraKind:
  203. raise TypeError(
  204. 'A bra or ket to a power is invalid, use TensorProduct instead.'
  205. )
  206. def _postprocess_tp_pow(expr):
  207. """Handle TensorProduct(*operators)**(positive integer).
  208. This handles a tensor product of operators, to an integer power.
  209. The power here is interpreted as regular multiplication, not
  210. tensor product exponentiation. The form of exponentiation performed
  211. here leaves the space and dimension of the object the same.
  212. This operation does not make sense for tensor product's of states.
  213. """
  214. base, exp = expr.as_base_exp()
  215. debug('_postprocess_tp_pow: ', base, exp, expr.args)
  216. if isinstance(base, TensorProduct) and exp.is_integer and exp.is_positive and base.kind == OperatorKind:
  217. new_args = [a**exp for a in base.args]
  218. return TensorProduct(*new_args)
  219. #-----------------------------------------------------------------------------
  220. # Register the transformers with Basic._constructor_postprocessor_mapping
  221. #-----------------------------------------------------------------------------
  222. Basic._constructor_postprocessor_mapping[StateBase] = {
  223. "Mul": [_postprocess_state_mul],
  224. "Pow": [_postprocess_state_pow]
  225. }
  226. Basic._constructor_postprocessor_mapping[TensorProduct] = {
  227. "Mul": [_postprocess_state_mul],
  228. "Pow": [_postprocess_tp_pow]
  229. }
  230. Basic._constructor_postprocessor_mapping[Operator] = {
  231. "Mul": [_postprocess_state_mul]
  232. }