_tensor_str.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import dataclasses
  4. import math
  5. import textwrap
  6. from typing import Any
  7. import torch
  8. from torch import inf
  9. @dataclasses.dataclass
  10. class __PrinterOptions:
  11. precision: int = 4
  12. threshold: float = 1000
  13. edgeitems: int = 3
  14. linewidth: int = 80
  15. sci_mode: bool | None = None
  16. PRINT_OPTS = __PrinterOptions()
  17. # We could use **kwargs, but this will give better docs
  18. def set_printoptions(
  19. precision=None,
  20. threshold=None,
  21. edgeitems=None,
  22. linewidth=None,
  23. profile=None,
  24. sci_mode=None,
  25. ):
  26. r"""Set options for printing. Items shamelessly taken from NumPy
  27. Args:
  28. precision: Number of digits of precision for floating point output
  29. (default = 4).
  30. threshold: Total number of array elements which trigger summarization
  31. rather than full `repr` (default = 1000).
  32. edgeitems: Number of array items in summary at beginning and end of
  33. each dimension (default = 3).
  34. linewidth: The number of characters per line for the purpose of
  35. inserting line breaks (default = 80). Thresholded matrices will
  36. ignore this parameter.
  37. profile: Sane defaults for pretty printing. Can override with any of
  38. the above options. (any one of `default`, `short`, `full`)
  39. sci_mode: Enable (True) or disable (False) scientific notation. If
  40. None (default) is specified, the value is defined by
  41. `torch._tensor_str._Formatter`. This value is automatically chosen
  42. by the framework.
  43. Example::
  44. >>> # Limit the precision of elements
  45. >>> torch.set_printoptions(precision=2)
  46. >>> torch.tensor([1.12345])
  47. tensor([1.12])
  48. >>> # Limit the number of elements shown
  49. >>> torch.set_printoptions(threshold=5)
  50. >>> torch.arange(10)
  51. tensor([0, 1, 2, ..., 7, 8, 9])
  52. >>> # Restore defaults
  53. >>> torch.set_printoptions(profile='default')
  54. >>> torch.tensor([1.12345])
  55. tensor([1.1235])
  56. >>> torch.arange(10)
  57. tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
  58. """
  59. if profile is not None:
  60. if profile == "default":
  61. PRINT_OPTS.precision = 4
  62. PRINT_OPTS.threshold = 1000
  63. PRINT_OPTS.edgeitems = 3
  64. PRINT_OPTS.linewidth = 80
  65. elif profile == "short":
  66. PRINT_OPTS.precision = 2
  67. PRINT_OPTS.threshold = 1000
  68. PRINT_OPTS.edgeitems = 2
  69. PRINT_OPTS.linewidth = 80
  70. elif profile == "full":
  71. PRINT_OPTS.precision = 4
  72. PRINT_OPTS.threshold = inf
  73. PRINT_OPTS.edgeitems = 3
  74. PRINT_OPTS.linewidth = 80
  75. if precision is not None:
  76. PRINT_OPTS.precision = precision
  77. if threshold is not None:
  78. PRINT_OPTS.threshold = threshold
  79. if edgeitems is not None:
  80. PRINT_OPTS.edgeitems = edgeitems
  81. if linewidth is not None:
  82. PRINT_OPTS.linewidth = linewidth
  83. PRINT_OPTS.sci_mode = sci_mode
  84. def get_printoptions() -> dict[str, Any]:
  85. r"""Gets the current options for printing, as a dictionary that
  86. can be passed as ``**kwargs`` to set_printoptions().
  87. """
  88. return dataclasses.asdict(PRINT_OPTS)
  89. @contextlib.contextmanager
  90. def printoptions(**kwargs):
  91. r"""Context manager that temporarily changes the print options. Accepted
  92. arguments are same as :func:`set_printoptions`."""
  93. old_kwargs = get_printoptions()
  94. set_printoptions(**kwargs)
  95. try:
  96. yield
  97. finally:
  98. set_printoptions(**old_kwargs)
  99. def tensor_totype(t):
  100. dtype = (
  101. torch.float
  102. if (
  103. t.is_mps
  104. or (t.is_xpu and not torch.xpu.get_device_properties(t.device).has_fp64)
  105. or t.is_maia
  106. )
  107. else torch.double
  108. )
  109. return t.to(dtype=dtype)
class _Formatter:
    """Choose one shared formatting style for the elements of a tensor.

    The constructor scans ``tensor`` once and records:

    * ``floating_dtype`` -- whether ``tensor.dtype`` is a floating point dtype;
    * ``int_mode`` -- for floating dtypes, whether every nonzero finite value
      equals its ceiling (so values can be printed as e.g. ``"3."``);
    * ``sci_mode`` -- whether scientific notation is used (chosen from the
      value range unless ``PRINT_OPTS.sci_mode`` is set);
    * ``max_width`` -- the widest formatted element, used as column width.

    :meth:`format` then renders a single value with those settings.
    """

    def __init__(self, tensor):
        # Defaults; they are kept as-is if no nonzero finite value is found below.
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            # Non-floating dtypes: width is simply the longest f-string rendering.
            for value in tensor_view:
                value_str = f"{value}"
                self.max_width = max(self.max_width, len(value_str))

        else:
            if tensor.dtype == torch.float4_e2m1fn_x2:  # type: ignore[attr-defined]
                # torch.float4_e2m1fn_x2 is special and does not support the casts necessary
                # to print it, we choose to display the uint8 representation here for
                # convenience of being able to print a tensor.
                # TODO(#146647): extend this to other dtypes without casts defined, such
                # as the bits, uint1..7 and int1..7 dtypes.
                tensor_view = tensor_view.view(torch.uint8)

            # Only nonzero finite values influence int_mode / sci_mode / width;
            # zeros, infs and nans are ignored here.
            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
            )

            if nonzero_finite_vals.numel() == 0:
                # no valid number, do nothing
                return

            if tensor.dtype == torch.float8_e8m0fnu:  # type: ignore[attr-defined]
                # float8_e8m0fnu is special and does not define arithmetic ops,
                # and printing code further in this file assumes the existence
                # of various arithmetic ops to figure out what to print. We hack
                # and convert to float here to make printing work correctly.
                # TODO(#113663): also add the other float8 dtypes here after arithmetic
                # support for them is removed
                nonzero_finite_vals = nonzero_finite_vals.float()

            # Convert to double (or float) for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
            nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
            nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
            nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())

            # int_mode stays True only if no value has a fractional part.
            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            # Automatic rule (used when PRINT_OPTS.sci_mode is None): switch to
            # scientific notation when the magnitudes span more than 3 orders,
            # exceed 1e8, or go below 1e-4.
            self.sci_mode = (
                nonzero_finite_max / nonzero_finite_min > 1000.0
                or nonzero_finite_max > 1.0e8
                or nonzero_finite_min < 1.0e-4
                if PRINT_OPTS.sci_mode is None
                else PRINT_OPTS.sci_mode
            )

            if self.int_mode:
                # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
                # to indicate that the tensor is of floating type. add 1 to the len to account for this.
                if self.sci_mode:
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{value:.0f}"
                        self.max_width = max(self.max_width, len(value_str) + 1)
            else:
                # Check if scientific representation should be used.
                if self.sci_mode:
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))

    def width(self):
        """Return the column width computed by the constructor."""
        return self.max_width

    def format(self, value):
        """Format one element, left-padded with spaces to ``max_width``.

        NOTE(review): ``value`` appears to be a Python number -- callers in this
        file pass ``.item()`` / ``.tolist()`` results, and ``math.isinf`` /
        ``math.isnan`` are applied to it in int mode. Confirm before passing
        tensors.
        """
        if self.floating_dtype:
            if self.sci_mode:
                ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value)
            elif self.int_mode:
                ret = f"{value:.0f}"
                # A trailing "." marks finite values as floating point.
                if not (math.isinf(value) or math.isnan(value)):
                    ret += "."
            else:
                ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
        else:
            ret = f"{value}"
        # Right-align within the shared column width.
        return (self.max_width - len(ret)) * " " + ret
  195. def _scalar_str(self, formatter1, formatter2=None):
  196. if formatter2 is not None:
  197. real_str = _scalar_str(self.real, formatter1)
  198. imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
  199. # handles negative numbers, +0.0, -0.0
  200. if imag_str[0] == "+" or imag_str[0] == "-":
  201. return real_str + imag_str
  202. else:
  203. return real_str + "+" + imag_str
  204. else:
  205. return formatter1.format(self.item())
  206. def _vector_str(self, indent, summarize, formatter1, formatter2=None):
  207. # length includes spaces and comma between elements
  208. element_length = formatter1.width() + 2
  209. if formatter2 is not None:
  210. # width for imag_formatter + an extra j for complex
  211. element_length += formatter2.width() + 1
  212. elements_per_line = max(
  213. 1, math.floor((PRINT_OPTS.linewidth - indent) / (element_length))
  214. )
  215. def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
  216. if formatter2 is not None:
  217. real_str = formatter1.format(val.real)
  218. imag_str = (formatter2.format(val.imag) + "j").lstrip()
  219. # handles negative numbers, +0.0, -0.0
  220. if imag_str[0] == "+" or imag_str[0] == "-":
  221. return real_str + imag_str
  222. else:
  223. return real_str + "+" + imag_str
  224. else:
  225. return formatter1.format(val)
  226. if self.dtype == torch.float4_e2m1fn_x2: # type: ignore[attr-defined]
  227. # torch.float4_e2m1fn_x2 is special and does not support the casts necessary
  228. # to print it, we choose to display the uint8 representation here for
  229. # convenience of being able to print a tensor.
  230. # TODO(#146647): extend this to other dtypes without casts defined, such
  231. # as the bits, uint1..7 and int1..7 dtypes.
  232. self = self.view(torch.uint8)
  233. if summarize and not PRINT_OPTS.edgeitems:
  234. # Deal with edge case that negative zero is zero
  235. data = ["..."]
  236. elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
  237. data = (
  238. [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()]
  239. + [" ..."]
  240. + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()]
  241. )
  242. else:
  243. data = [_val_formatter(val) for val in self.tolist()]
  244. data_lines = [
  245. data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)
  246. ]
  247. lines = [", ".join(line) for line in data_lines]
  248. return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"
  249. # formatter2 is only used for printing complex tensors.
  250. # For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
  251. # and tensor.imag respesectively
  252. def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
  253. dim = self.dim()
  254. if dim == 0:
  255. return _scalar_str(self, formatter1, formatter2)
  256. if dim == 1:
  257. return _vector_str(self, indent, summarize, formatter1, formatter2)
  258. if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
  259. slices = (
  260. [
  261. _tensor_str_with_formatter(
  262. self[i], indent + 1, summarize, formatter1, formatter2
  263. )
  264. for i in range(PRINT_OPTS.edgeitems)
  265. ]
  266. + ["..."]
  267. + [
  268. _tensor_str_with_formatter(
  269. self[i], indent + 1, summarize, formatter1, formatter2
  270. )
  271. for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))
  272. ]
  273. )
  274. else:
  275. slices = [
  276. _tensor_str_with_formatter(
  277. self[i], indent + 1, summarize, formatter1, formatter2
  278. )
  279. for i in range(self.size(0))
  280. ]
  281. tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices)
  282. return "[" + tensor_str + "]"
  283. def _tensor_str(self, indent):
  284. if self.numel() == 0:
  285. return "[]"
  286. if self.has_names():
  287. # There are two main codepaths (possibly more) that tensor printing goes through:
  288. # - tensor data can fit comfortably on screen
  289. # - tensor data needs to be summarized
  290. # Some of the codepaths don't fully support named tensors, so we send in
  291. # an unnamed tensor to the formatting code as a workaround.
  292. self = self.rename(None)
  293. summarize = self.numel() > PRINT_OPTS.threshold
  294. if self._is_zerotensor():
  295. self = self.clone()
  296. # handle the negative bit
  297. if self.is_neg():
  298. self = self.resolve_neg()
  299. # TODO: Remove me when `masked_select` is implemented for FP8
  300. if self.dtype in [
  301. torch.float8_e5m2,
  302. torch.float8_e5m2fnuz,
  303. torch.float8_e4m3fn,
  304. torch.float8_e4m3fnuz,
  305. ]:
  306. self = self.half()
  307. if self.dtype.is_complex:
  308. # handle the conjugate bit
  309. self = self.resolve_conj()
  310. real_formatter = _Formatter(
  311. get_summarized_data(self.real) if summarize else self.real
  312. )
  313. imag_formatter = _Formatter(
  314. get_summarized_data(self.imag) if summarize else self.imag
  315. )
  316. return _tensor_str_with_formatter(
  317. self, indent, summarize, real_formatter, imag_formatter
  318. )
  319. else:
  320. formatter = _Formatter(get_summarized_data(self) if summarize else self)
  321. return _tensor_str_with_formatter(self, indent, summarize, formatter)
  322. def _add_suffixes(tensor_str, suffixes, indent, force_newline):
  323. tensor_strs = [tensor_str]
  324. last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
  325. for suffix in suffixes:
  326. suffix_len = len(suffix)
  327. if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
  328. tensor_strs.append(",\n" + " " * indent + suffix)
  329. last_line_len = indent + suffix_len
  330. force_newline = False
  331. else:
  332. tensor_strs.append(", " + suffix)
  333. last_line_len += suffix_len + 2
  334. tensor_strs.append(")")
  335. return "".join(tensor_strs)
  336. def get_summarized_data(self):
  337. dim = self.dim()
  338. if dim == 0:
  339. return self
  340. if dim == 1:
  341. if self.size(0) > 2 * PRINT_OPTS.edgeitems:
  342. return torch.cat(
  343. (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])
  344. )
  345. else:
  346. return self
  347. if not PRINT_OPTS.edgeitems:
  348. return self.new_empty([0] * self.dim())
  349. elif self.size(0) > 2 * PRINT_OPTS.edgeitems:
  350. start = [self[i] for i in range(PRINT_OPTS.edgeitems)]
  351. end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]
  352. return torch.stack([get_summarized_data(x) for x in (start + end)])
  353. else:
  354. return torch.stack([get_summarized_data(x) for x in self])
def _str_intern(inp, *, tensor_contents=None):
    """Build the repr string for ``inp``.

    The result is ``<prefix><contents>`` followed by ", "-separated metadata
    suffixes (device, size, nnz, dtype, quantization parameters, layout,
    grad_fn / requires_grad, names, tangent) and a closing ")". The prefix is
    ``tensor(``, ``nested_tensor(`` or ``<SubclassName>(``, and the result is
    additionally wrapped in ``Parameter(...)`` for Parameter subclasses.

    Args:
        inp: the tensor to render. Functorch-wrapped tensors are delegated to
            ``_functorch_wrapper_str_intern``.
        tensor_contents: if not None, used verbatim as the contents string
            instead of formatting the data (the functional-tensor branch still
            overwrites it with the unwrapped tensor's repr).

    Returns:
        The repr string.
    """
    if torch._C._functorch.is_functorch_wrapped_tensor(inp):
        return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
    # Parameters are treated as plain here so they keep the "tensor(" prefix.
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    # Continuation lines of the contents and suffixes are aligned to this column.
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if (
        self.device.type != torch._C._get_default_device()
        or (
            self.device.type == "cuda"
            and torch.cuda.current_device() != self.device.index
        )
        or (self.device.type == "mps")
    ):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
    # representable format. These operations on ipu/xla/lazy/mtia tensor results in compilations. Hence,
    # to avoid compilations, copying the tensor to cpu before printing.
    if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
        self = self.to("cpu")

    # TODO: add an API to map real -> complex dtypes
    _default_complex_dtype = (
        torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    )
    # dtypes in this set are omitted from the "dtype=" suffix in most branches
    # below (the meta/fake and empty-tensor branches compare against the
    # default dtype only).
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        # Sparse COO: contents are "indices=tensor(...)" and "values=tensor(...)"
        # on separate lines.
        suffixes.append("size=" + str(tuple(self.shape)))
        from torch._subclasses.fake_tensor import FakeTensor

        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            if is_meta:
                indices_str = "..."
            else:
                indices_str = _tensor_str(indices, indent + len(indices_prefix))
            if is_meta or indices.numel() == 0:
                indices_str += ", size=" + str(tuple(indices.shape))
            values_prefix = "values=tensor("
            values = self._values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if is_meta or values.numel() == 0:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                indices_prefix
                + indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        # Compressed sparse layouts: contents are the compressed indices, the
        # plain indices and the values, each on its own line.
        from torch._subclasses.fake_tensor import FakeTensor

        suffixes.append("size=" + str(tuple(self.shape)))
        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            # Produces labels such as "crow_indices=" / "col_indices=".
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            if is_meta:
                compressed_indices_str = "..."
            else:
                compressed_indices_str = _tensor_str(
                    compressed_indices, indent + len(compressed_indices_prefix)
                )
            if compressed_indices.numel() == 0 or is_meta:
                compressed_indices_str += ", size=" + str(
                    tuple(compressed_indices.shape)
                )
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            if is_meta:
                plain_indices_str = "..."
            else:
                plain_indices_str = _tensor_str(
                    plain_indices, indent + len(plain_indices_prefix)
                )
            if plain_indices.numel() == 0 or is_meta:
                plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
            values_prefix = "values=tensor("
            values = self.values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                compressed_indices_prefix
                + compressed_indices_str
                + "),\n"
                + " " * indent
                + plain_indices_prefix
                + plain_indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.is_quantized:
        # Quantized: contents are the dequantized values; quantization
        # parameters are reported as suffixes.
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (
            self.qscheme() == torch.per_tensor_affine
            or self.qscheme() == torch.per_tensor_symmetric
        ):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (
            self.qscheme() == torch.per_channel_affine
            or self.qscheme() == torch.per_channel_symmetric
            or self.qscheme() == torch.per_channel_affine_float_qparams
        ):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        # Nested: each component (unbound along dim 0) is printed via str() and
        # shifted right by a fixed prefix, inside "[\n ... \n]".
        if not custom_contents_provided:

            # NOTE(review): the `indent` argument is unused; every line gets the
            # same fixed prefix regardless of nesting depth.
            def indented_str(s, indent):
                return "\n".join(f" {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0)
            )
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        # Functional tensor: print the unwrapped tensor's repr under a
        # different prefix (this overrides tensor_contents if one was given).
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        # Circular import problem, so we import it here
        from torch._subclasses.fake_tensor import FakeTensor

        if self.is_meta or isinstance(self, FakeTensor):
            # No data to show: contents become "..." and the shape is reported.
            suffixes.append("size=" + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor or FakeTensor, which it could be, but it isn't right now
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                # With edgeitems == 0 the contents collapse to "...", so report
                # the shape explicitly.
                if not PRINT_OPTS.edgeitems:
                    suffixes.append("size=" + str(tuple(self.shape)))

                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    grad_fn_name = None
    try:
        grad_fn = inp.grad_fn
    except RuntimeError:
        # Accessing the grad_fn calls rebasing logic which would cause an error
        # if that tensor is a view created in no-grad mode modified in-place in
        # no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
        grad_fn_name = "Invalid"

    # grad_fn is only unbound when the except branch ran, and then
    # grad_fn_name is already set, so short-circuiting avoids reading it.
    if grad_fn_name is None and grad_fn is not None:  # type: ignore[possibly-undefined]
        # pyrefly: ignore [unbound-name]
        grad_fn_name = type(grad_fn).__name__
        if grad_fn_name == "CppFunction":
            # Show the underlying op name (last "::" component) instead.
            # pyrefly: ignore [unbound-name]
            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

    if grad_fn_name is not None:
        suffixes.append(f"grad_fn=<{grad_fn_name}>")
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append(f"names={self.names}")

    if tangent is not None:
        suffixes.append(f"tangent={tangent}")

    string_repr = _add_suffixes(
        prefix + tensor_str,  # type: ignore[possibly-undefined]
        suffixes,
        indent,
        force_newline=self.is_sparse,
    )

    # Check if this instance is flagged as a parameter and change the repr accordingly.
    # Unfortunately, this function has to be aware of this detail.
    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
    # this should be done for those as well to produce a valid repr.
    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr
  610. def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
  611. level = torch._C._functorch.maybe_get_level(tensor)
  612. if level == -1:
  613. raise AssertionError("expected functorch level to be >= 0, got -1")
  614. if torch._C._functorch.is_functionaltensor(tensor):
  615. # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
  616. # that it's up to date first
  617. torch._sync(tensor)
  618. value = torch._C._functorch.get_unwrapped(tensor)
  619. value_repr = repr(value)
  620. indented_value_repr = textwrap.indent(value_repr, " " * 4)
  621. if torch._C._functorch.is_batchedtensor(tensor):
  622. bdim = torch._C._functorch.maybe_get_bdim(tensor)
  623. if bdim == -1:
  624. raise AssertionError("expected batch dimension to be >= 0, got -1")
  625. return (
  626. f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n{indented_value_repr}\n)"
  627. )
  628. if torch._C._functorch.is_gradtrackingtensor(tensor):
  629. return f"GradTrackingTensor(lvl={level}, value=\n{indented_value_repr}\n)"
  630. if torch._C._functorch.is_functionaltensor(tensor):
  631. return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"
  632. raise ValueError("We don't know how to print this, please file us an issue")
  633. def _str(self, *, tensor_contents=None):
  634. with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
  635. guard = torch._C._DisableFuncTorch() # noqa: F841
  636. return _str_intern(self, tensor_contents=tensor_contents)