pytorch.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter
  2. from sympy.matrices.expressions import MatrixExpr
  3. from sympy.core.mul import Mul
  4. from sympy.printing.precedence import PRECEDENCE
  5. from sympy.external import import_module
  6. from sympy.codegen.cfunctions import Sqrt
  7. from sympy import S
  8. from sympy import Integer
  9. import sympy
  10. torch = import_module('torch')
  11. class TorchPrinter(ArrayPrinter, AbstractPythonCodePrinter):
  12. printmethod = "_torchcode"
  13. mapping = {
  14. sympy.Abs: "torch.abs",
  15. sympy.sign: "torch.sign",
  16. # XXX May raise error for ints.
  17. sympy.ceiling: "torch.ceil",
  18. sympy.floor: "torch.floor",
  19. sympy.log: "torch.log",
  20. sympy.exp: "torch.exp",
  21. Sqrt: "torch.sqrt",
  22. sympy.cos: "torch.cos",
  23. sympy.acos: "torch.acos",
  24. sympy.sin: "torch.sin",
  25. sympy.asin: "torch.asin",
  26. sympy.tan: "torch.tan",
  27. sympy.atan: "torch.atan",
  28. sympy.atan2: "torch.atan2",
  29. # XXX Also may give NaN for complex results.
  30. sympy.cosh: "torch.cosh",
  31. sympy.acosh: "torch.acosh",
  32. sympy.sinh: "torch.sinh",
  33. sympy.asinh: "torch.asinh",
  34. sympy.tanh: "torch.tanh",
  35. sympy.atanh: "torch.atanh",
  36. sympy.Pow: "torch.pow",
  37. sympy.re: "torch.real",
  38. sympy.im: "torch.imag",
  39. sympy.arg: "torch.angle",
  40. # XXX May raise error for ints and complexes
  41. sympy.erf: "torch.erf",
  42. sympy.loggamma: "torch.lgamma",
  43. sympy.Eq: "torch.eq",
  44. sympy.Ne: "torch.ne",
  45. sympy.StrictGreaterThan: "torch.gt",
  46. sympy.StrictLessThan: "torch.lt",
  47. sympy.LessThan: "torch.le",
  48. sympy.GreaterThan: "torch.ge",
  49. sympy.And: "torch.logical_and",
  50. sympy.Or: "torch.logical_or",
  51. sympy.Not: "torch.logical_not",
  52. sympy.Max: "torch.max",
  53. sympy.Min: "torch.min",
  54. # Matrices
  55. sympy.MatAdd: "torch.add",
  56. sympy.HadamardProduct: "torch.mul",
  57. sympy.Trace: "torch.trace",
  58. # XXX May raise error for integer matrices.
  59. sympy.Determinant: "torch.det",
  60. }
  61. _default_settings = dict(
  62. AbstractPythonCodePrinter._default_settings,
  63. torch_version=None,
  64. requires_grad=False,
  65. dtype="torch.float64",
  66. )
  67. def __init__(self, settings=None):
  68. super().__init__(settings)
  69. version = self._settings['torch_version']
  70. self.requires_grad = self._settings['requires_grad']
  71. self.dtype = self._settings['dtype']
  72. if version is None and torch:
  73. version = torch.__version__
  74. self.torch_version = version
  75. def _print_Function(self, expr):
  76. op = self.mapping.get(type(expr), None)
  77. if op is None:
  78. return super()._print_Basic(expr)
  79. children = [self._print(arg) for arg in expr.args]
  80. if len(children) == 1:
  81. return "%s(%s)" % (
  82. self._module_format(op),
  83. children[0]
  84. )
  85. else:
  86. return self._expand_fold_binary_op(op, children)
  87. # mirrors the tensorflow version
  88. _print_Expr = _print_Function
  89. _print_Application = _print_Function
  90. _print_MatrixExpr = _print_Function
  91. _print_Relational = _print_Function
  92. _print_Not = _print_Function
  93. _print_And = _print_Function
  94. _print_Or = _print_Function
  95. _print_HadamardProduct = _print_Function
  96. _print_Trace = _print_Function
  97. _print_Determinant = _print_Function
  98. def _print_Inverse(self, expr):
  99. return '{}({})'.format(self._module_format("torch.linalg.inv"),
  100. self._print(expr.args[0]))
  101. def _print_Transpose(self, expr):
  102. if expr.arg.is_Matrix and expr.arg.shape[0] == expr.arg.shape[1]:
  103. # For square matrices, we can use the .t() method
  104. return "{}({}).t()".format("torch.transpose", self._print(expr.arg))
  105. else:
  106. # For non-square matrices or more general cases
  107. # transpose first and second dimensions (typical matrix transpose)
  108. return "{}.permute({})".format(
  109. self._print(expr.arg),
  110. ", ".join([str(i) for i in range(len(expr.arg.shape))])[::-1]
  111. )
  112. def _print_PermuteDims(self, expr):
  113. return "%s.permute(%s)" % (
  114. self._print(expr.expr),
  115. ", ".join(str(i) for i in expr.permutation.array_form)
  116. )
  117. def _print_Derivative(self, expr):
  118. # this version handles multi-variable and mixed partial derivatives. The tensorflow version does not.
  119. variables = expr.variables
  120. expr_arg = expr.expr
  121. # Handle multi-variable or repeated derivatives
  122. if len(variables) > 1 or (
  123. len(variables) == 1 and not isinstance(variables[0], tuple) and variables.count(variables[0]) > 1):
  124. result = self._print(expr_arg)
  125. var_groups = {}
  126. # Group variables by base symbol
  127. for var in variables:
  128. if isinstance(var, tuple):
  129. base_var, order = var
  130. var_groups[base_var] = var_groups.get(base_var, 0) + order
  131. else:
  132. var_groups[var] = var_groups.get(var, 0) + 1
  133. # Apply gradients in sequence
  134. for var, order in var_groups.items():
  135. for _ in range(order):
  136. result = "torch.autograd.grad({}, {}, create_graph=True)[0]".format(result, self._print(var))
  137. return result
  138. # Handle single variable case
  139. if len(variables) == 1:
  140. variable = variables[0]
  141. if isinstance(variable, tuple) and len(variable) == 2:
  142. base_var, order = variable
  143. if not isinstance(order, Integer): raise NotImplementedError("Only integer orders are supported")
  144. result = self._print(expr_arg)
  145. for _ in range(order):
  146. result = "torch.autograd.grad({}, {}, create_graph=True)[0]".format(result, self._print(base_var))
  147. return result
  148. return "torch.autograd.grad({}, {})[0]".format(self._print(expr_arg), self._print(variable))
  149. return self._print(expr_arg) # Empty variables case
  150. def _print_Piecewise(self, expr):
  151. from sympy import Piecewise
  152. e, cond = expr.args[0].args
  153. if len(expr.args) == 1:
  154. return '{}({}, {}, {})'.format(
  155. self._module_format("torch.where"),
  156. self._print(cond),
  157. self._print(e),
  158. 0)
  159. return '{}({}, {}, {})'.format(
  160. self._module_format("torch.where"),
  161. self._print(cond),
  162. self._print(e),
  163. self._print(Piecewise(*expr.args[1:])))
  164. def _print_Pow(self, expr):
  165. # XXX May raise error for
  166. # int**float or int**complex or float**complex
  167. base, exp = expr.args
  168. if expr.exp == S.Half:
  169. return "{}({})".format(
  170. self._module_format("torch.sqrt"), self._print(base))
  171. return "{}({}, {})".format(
  172. self._module_format("torch.pow"),
  173. self._print(base), self._print(exp))
  174. def _print_MatMul(self, expr):
  175. # Separate matrix and scalar arguments
  176. mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)]
  177. args = [arg for arg in expr.args if arg not in mat_args]
  178. # Handle scalar multipliers if present
  179. if args:
  180. return "%s*%s" % (
  181. self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]),
  182. self._expand_fold_binary_op("torch.matmul", mat_args)
  183. )
  184. else:
  185. return self._expand_fold_binary_op("torch.matmul", mat_args)
  186. def _print_MatPow(self, expr):
  187. return self._expand_fold_binary_op("torch.mm", [expr.base]*expr.exp)
  188. def _print_MatrixBase(self, expr):
  189. data = "[" + ", ".join(["[" + ", ".join([self._print(j) for j in i]) + "]" for i in expr.tolist()]) + "]"
  190. params = [str(data)]
  191. params.append(f"dtype={self.dtype}")
  192. if self.requires_grad:
  193. params.append("requires_grad=True")
  194. return "{}({})".format(
  195. self._module_format("torch.tensor"),
  196. ", ".join(params)
  197. )
  198. def _print_isnan(self, expr):
  199. return f'torch.isnan({self._print(expr.args[0])})'
  200. def _print_isinf(self, expr):
  201. return f'torch.isinf({self._print(expr.args[0])})'
  202. def _print_Identity(self, expr):
  203. if all(dim.is_Integer for dim in expr.shape):
  204. return "{}({})".format(
  205. self._module_format("torch.eye"),
  206. self._print(expr.shape[0])
  207. )
  208. else:
  209. # For symbolic dimensions, fall back to a more general approach
  210. return "{}({}, {})".format(
  211. self._module_format("torch.eye"),
  212. self._print(expr.shape[0]),
  213. self._print(expr.shape[1])
  214. )
  215. def _print_ZeroMatrix(self, expr):
  216. return "{}({})".format(
  217. self._module_format("torch.zeros"),
  218. self._print(expr.shape)
  219. )
  220. def _print_OneMatrix(self, expr):
  221. return "{}({})".format(
  222. self._module_format("torch.ones"),
  223. self._print(expr.shape)
  224. )
  225. def _print_conjugate(self, expr):
  226. return f"{self._module_format('torch.conj')}({self._print(expr.args[0])})"
  227. def _print_ImaginaryUnit(self, expr):
  228. return "1j" # uses the Python built-in 1j notation for the imaginary unit
  229. def _print_Heaviside(self, expr):
  230. args = [self._print(expr.args[0]), "0.5"]
  231. if len(expr.args) > 1:
  232. args[1] = self._print(expr.args[1])
  233. return f"{self._module_format('torch.heaviside')}({args[0]}, {args[1]})"
  234. def _print_gamma(self, expr):
  235. return f"{self._module_format('torch.special.gamma')}({self._print(expr.args[0])})"
  236. def _print_polygamma(self, expr):
  237. if expr.args[0] == S.Zero:
  238. return f"{self._module_format('torch.special.digamma')}({self._print(expr.args[1])})"
  239. else:
  240. raise NotImplementedError("PyTorch only supports digamma (0th order polygamma)")
  241. _module = "torch"
  242. _einsum = "einsum"
  243. _add = "add"
  244. _transpose = "t"
  245. _ones = "ones"
  246. _zeros = "zeros"
  247. def torch_code(expr, requires_grad=False, dtype="torch.float64", **settings):
  248. printer = TorchPrinter(settings={'requires_grad': requires_grad, 'dtype': dtype})
  249. return printer.doprint(expr, **settings)