_backends.py 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764
"""
Backends in `einops` are organized to meet the following requirements:

- backends are not imported unless they are actually needed, because
    - a backend may not be installed;
    - importing all available backends would lead to a significant memory footprint;
    - a backend may be installed with errors (but never used), and importing it
      could lead to crashes.
- a backend should be either symbolic or imperative
    - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol)
      should be defined.
- if a backend can't provide symbols for shape dimensions, UnknownSize objects are used.
"""

import sys

__author__ = "Alex Rogozhnikov"

# framework_name -> backend instance, for every backend get_backend has instantiated so far
_loaded_backends: dict = {}
# type(tensor) -> backend instance; get_backend caches every successful lookup here
_type2backend: dict = {}
# when True, get_backend prints which backend subclasses it tests and instantiates
_debug_importing = False
  17. def get_backend(tensor) -> "AbstractBackend":
  18. """
  19. Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
  20. If needed, imports package and creates backend
  21. """
  22. _type = type(tensor)
  23. _result = _type2backend.get(_type, None)
  24. if _result is not None:
  25. return _result
  26. previously_loaded_backends = list(_loaded_backends.items())
  27. for _framework_name, backend in previously_loaded_backends:
  28. if backend.is_appropriate_type(tensor):
  29. _type2backend[_type] = backend
  30. return backend
  31. # Find backend subclasses recursively
  32. backend_subclasses = []
  33. backends = AbstractBackend.__subclasses__()
  34. while backends:
  35. backend = backends.pop()
  36. backends += backend.__subclasses__()
  37. backend_subclasses.append(backend)
  38. # handles modification of _loaded_backends from other thread, see #391
  39. prev_backend_names = [x for x, _ in previously_loaded_backends]
  40. for BackendSubclass in backend_subclasses:
  41. if _debug_importing:
  42. print("Testing for subclass of ", BackendSubclass)
  43. if BackendSubclass.framework_name not in prev_backend_names:
  44. # check that module was already imported. Otherwise it can't be imported
  45. if BackendSubclass.framework_name in sys.modules:
  46. if _debug_importing:
  47. print("Imported backend for ", BackendSubclass.framework_name)
  48. backend = BackendSubclass()
  49. _loaded_backends[backend.framework_name] = backend
  50. if backend.is_appropriate_type(tensor):
  51. _type2backend[_type] = backend
  52. return backend
  53. raise RuntimeError(f"Tensor type unknown to einops {type(tensor)}")
  54. class AbstractBackend:
  55. """Base backend class, major part of methods are only for debugging purposes."""
  56. framework_name: str
  57. def is_appropriate_type(self, tensor):
  58. """helper method should recognize tensors it can handle"""
  59. raise NotImplementedError()
  60. def from_numpy(self, x):
  61. raise NotImplementedError("framework doesn't support imperative execution")
  62. def to_numpy(self, x):
  63. raise NotImplementedError("framework doesn't support imperative execution")
  64. def create_symbol(self, shape):
  65. raise NotImplementedError("framework doesn't support symbolic computations")
  66. def eval_symbol(self, symbol, symbol_value_pairs):
  67. # symbol-value pairs is list[tuple[symbol, value-tensor]]
  68. raise NotImplementedError("framework doesn't support symbolic computations")
  69. def arange(self, start, stop):
  70. # supplementary method used only in testing, so should implement CPU version
  71. raise NotImplementedError("framework doesn't implement arange")
  72. def shape(self, x):
  73. """shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)"""
  74. return x.shape
  75. def reshape(self, x, shape):
  76. return x.reshape(shape)
  77. def transpose(self, x, axes):
  78. return x.transpose(axes)
  79. def reduce(self, x, operation, axes):
  80. return getattr(x, operation)(axis=axes)
  81. def stack_on_zeroth_dimension(self, tensors: list):
  82. raise NotImplementedError()
  83. def add_axis(self, x, new_position):
  84. raise NotImplementedError()
  85. def add_axes(self, x, n_axes, pos2len):
  86. repeats = [1] * n_axes
  87. for axis_position, axis_length in pos2len.items():
  88. x = self.add_axis(x, axis_position)
  89. repeats[axis_position] = axis_length
  90. return self.tile(x, tuple(repeats))
  91. def tile(self, x, repeats):
  92. """repeats - same lengths as x.shape"""
  93. raise NotImplementedError()
  94. def concat(self, tensors, axis: int):
  95. """concatenates tensors along axis.
  96. Assume identical across tensors: devices, dtypes and shapes except selected axis."""
  97. raise NotImplementedError()
  98. def is_float_type(self, x):
  99. # some backends (torch) can't compute average for non-floating types.
  100. # Decided to drop average for all backends if type is not floating
  101. raise NotImplementedError()
  102. def layers(self):
  103. raise NotImplementedError("backend does not provide layers")
  104. def __repr__(self):
  105. return f"<einops backend for {self.framework_name}>"
  106. def einsum(self, pattern, *x):
  107. raise NotImplementedError("backend does not support einsum")
  108. class UnknownSize:
  109. """pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements"""
  110. def __floordiv__(self, other):
  111. return self
  112. def __eq__(self, other):
  113. return True # we don't know actual size
  114. def __mul__(self, other):
  115. return self
  116. def __rmul__(self, other):
  117. return self
  118. def __hash__(self):
  119. return hash(None)
  120. class NumpyBackend(AbstractBackend):
  121. framework_name = "numpy"
  122. def __init__(self):
  123. import numpy
  124. self.np = numpy
  125. def is_appropriate_type(self, tensor):
  126. return isinstance(tensor, self.np.ndarray)
  127. def from_numpy(self, x):
  128. return x
  129. def to_numpy(self, x):
  130. return x
  131. def arange(self, start, stop):
  132. return self.np.arange(start, stop)
  133. def stack_on_zeroth_dimension(self, tensors: list):
  134. return self.np.stack(tensors)
  135. def tile(self, x, repeats):
  136. return self.np.tile(x, repeats)
  137. def concat(self, tensors, axis: int):
  138. return self.np.concatenate(tensors, axis=axis)
  139. def is_float_type(self, x):
  140. return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")
  141. def add_axis(self, x, new_position):
  142. return self.np.expand_dims(x, new_position)
  143. def einsum(self, pattern, *x):
  144. return self.np.einsum(pattern, *x)
  145. class JaxBackend(NumpyBackend):
  146. framework_name = "jax"
  147. def __init__(self):
  148. super().__init__()
  149. self.onp = self.np
  150. import jax.numpy
  151. self.np = jax.numpy
  152. def from_numpy(self, x):
  153. return self.np.asarray(x)
  154. def to_numpy(self, x):
  155. return self.onp.asarray(x)
  156. class TorchBackend(AbstractBackend):
  157. framework_name = "torch"
  158. def __init__(self):
  159. import torch
  160. self.torch = torch
  161. # importing would register operations in torch._dynamo for torch.compile
  162. from . import _torch_specific # noqa
  163. def is_appropriate_type(self, tensor):
  164. return isinstance(tensor, self.torch.Tensor)
  165. def from_numpy(self, x):
  166. variable = self.torch.from_numpy(x)
  167. if self.is_float_type(variable):
  168. # attach grad only to floating types
  169. variable.requires_grad = True
  170. return variable
  171. def to_numpy(self, x):
  172. return x.detach().cpu().numpy()
  173. def arange(self, start, stop):
  174. return self.torch.arange(start, stop, dtype=self.torch.int64)
  175. def reduce(self, x, operation, reduced_axes):
  176. if operation == "min":
  177. return x.amin(dim=reduced_axes)
  178. elif operation == "max":
  179. return x.amax(dim=reduced_axes)
  180. elif operation == "sum":
  181. return x.sum(dim=reduced_axes)
  182. elif operation == "mean":
  183. return x.mean(dim=reduced_axes)
  184. elif operation in ("any", "all", "prod"):
  185. # pytorch supports reducing only one operation at a time
  186. for i in sorted(reduced_axes)[::-1]:
  187. x = getattr(x, operation)(dim=i)
  188. return x
  189. else:
  190. raise NotImplementedError("Unknown reduction ", operation)
  191. def transpose(self, x, axes):
  192. return x.permute(axes)
  193. def stack_on_zeroth_dimension(self, tensors: list):
  194. return self.torch.stack(tensors)
  195. def add_axes(self, x, n_axes, pos2len):
  196. repeats = [-1] * n_axes
  197. for axis_position, axis_length in pos2len.items():
  198. x = self.add_axis(x, axis_position)
  199. repeats[axis_position] = axis_length
  200. return x.expand(repeats)
  201. def tile(self, x, repeats):
  202. return x.repeat(repeats)
  203. def concat(self, tensors, axis: int):
  204. return self.torch.cat(tensors, dim=axis)
  205. def add_axis(self, x, new_position):
  206. return self.torch.unsqueeze(x, new_position)
  207. def is_float_type(self, x):
  208. return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16]
  209. def layers(self):
  210. from .layers import torch
  211. return torch
  212. def einsum(self, pattern, *x):
  213. return self.torch.einsum(pattern, *x)
  214. class CupyBackend(AbstractBackend):
  215. framework_name = "cupy"
  216. def __init__(self):
  217. import cupy
  218. self.cupy = cupy
  219. def is_appropriate_type(self, tensor):
  220. return isinstance(tensor, self.cupy.ndarray)
  221. def from_numpy(self, x):
  222. return self.cupy.asarray(x)
  223. def to_numpy(self, x):
  224. return self.cupy.asnumpy(x)
  225. def arange(self, start, stop):
  226. return self.cupy.arange(start, stop)
  227. def stack_on_zeroth_dimension(self, tensors: list):
  228. return self.cupy.stack(tensors)
  229. def tile(self, x, repeats):
  230. return self.cupy.tile(x, repeats)
  231. def concat(self, tensors, axis: int):
  232. return self.cupy.concatenate(tensors, axis=axis)
  233. def add_axis(self, x, new_position):
  234. return self.cupy.expand_dims(x, new_position)
  235. def is_float_type(self, x):
  236. return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")
  237. def einsum(self, pattern, *x):
  238. return self.cupy.einsum(pattern, *x)
  239. class HashableTuple:
  240. """Overcomes non-hashability of symbolic elements"""
  241. def __init__(self, elements: tuple):
  242. self.elements = elements
  243. def __iter__(self):
  244. yield from self.elements
  245. def __len__(self):
  246. return len(self.elements)
  247. def __getitem__(self, item):
  248. return self.elements[item]
  249. # default equality and hash is used (True only with itself, hash taken of id)
  250. class TensorflowBackend(AbstractBackend):
  251. framework_name = "tensorflow"
  252. def __init__(self):
  253. import tensorflow
  254. self.tf = tensorflow
  255. def is_appropriate_type(self, tensor):
  256. return isinstance(tensor, (self.tf.Tensor, self.tf.Variable))
  257. def from_numpy(self, x):
  258. assert self.tf.executing_eagerly()
  259. return self.tf.convert_to_tensor(x)
  260. def to_numpy(self, x):
  261. assert self.tf.executing_eagerly()
  262. return x.numpy()
  263. def arange(self, start, stop):
  264. return self.tf.range(start, stop)
  265. def shape(self, x):
  266. if self.tf.executing_eagerly():
  267. return tuple(UnknownSize() if d is None else int(d) for d in x.shape)
  268. else:
  269. static_shape = x.shape.as_list()
  270. tf_shape = self.tf.shape(x)
  271. # use the static shape where known, otherwise use the TF shape components
  272. shape = tuple([s or tf_shape[dim] for dim, s in enumerate(static_shape)])
  273. try:
  274. hash(shape)
  275. return shape
  276. except BaseException:
  277. # unhashable symbols in shape. Wrap tuple to be hashable.
  278. return HashableTuple(shape)
  279. def reduce(self, x, operation, axes):
  280. return getattr(self.tf, "reduce_" + operation)(x, axis=axes)
  281. def reshape(self, x, shape):
  282. return self.tf.reshape(x, shape)
  283. def transpose(self, x, axes):
  284. return self.tf.transpose(x, axes)
  285. def stack_on_zeroth_dimension(self, tensors: list):
  286. return self.tf.stack(tensors)
  287. def tile(self, x, repeats):
  288. return self.tf.tile(x, repeats)
  289. def concat(self, tensors, axis: int):
  290. return self.tf.concat(tensors, axis=axis)
  291. def add_axis(self, x, new_position):
  292. return self.tf.expand_dims(x, new_position)
  293. def is_float_type(self, x):
  294. return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")
  295. def layers(self):
  296. from .layers import tensorflow
  297. return tensorflow
  298. def einsum(self, pattern, *x):
  299. return self.tf.einsum(pattern, *x)
  300. class TFKerasBackend(AbstractBackend):
  301. framework_name = "tensorflow.keras"
  302. def __init__(self):
  303. import tensorflow as tf
  304. self.tf = tf
  305. self.keras = tf.keras
  306. self.K = tf.keras.backend
  307. def is_appropriate_type(self, tensor):
  308. return self.tf.is_tensor(tensor) and self.K.is_keras_tensor(tensor)
  309. def create_symbol(self, shape):
  310. return self.keras.Input(batch_shape=shape)
  311. def eval_symbol(self, symbol, symbol_value_pairs):
  312. model = self.keras.models.Model([var for (var, _) in symbol_value_pairs], symbol)
  313. return model.predict_on_batch([val for (_, val) in symbol_value_pairs])
  314. def arange(self, start, stop):
  315. return self.K.arange(start, stop)
  316. def shape(self, x):
  317. shape = self.K.shape(x) # tf tensor
  318. return HashableTuple(tuple(shape))
  319. def reduce(self, x, operation, axes):
  320. return getattr(self.K, operation)(x, axis=axes)
  321. def reshape(self, x, shape):
  322. return self.K.reshape(x, shape)
  323. def transpose(self, x, axes):
  324. return self.K.permute_dimensions(x, axes)
  325. def stack_on_zeroth_dimension(self, tensors: list):
  326. return self.K.stack(tensors)
  327. def tile(self, x, repeats):
  328. return self.K.tile(x, repeats)
  329. def concat(self, tensors, axis: int):
  330. return self.K.concatenate(tensors, axis=axis)
  331. def add_axis(self, x, new_position):
  332. return self.K.expand_dims(x, new_position)
  333. def is_float_type(self, x):
  334. return "float" in self.K.dtype(x)
  335. def layers(self):
  336. from .layers import keras
  337. return keras
  338. class OneFlowBackend(AbstractBackend):
  339. framework_name = "oneflow"
  340. def __init__(self):
  341. import oneflow as flow
  342. self.flow = flow
  343. def is_appropriate_type(self, tensor):
  344. return isinstance(tensor, self.flow.Tensor)
  345. def from_numpy(self, x):
  346. variable = self.flow.from_numpy(x)
  347. if self.is_float_type(variable):
  348. # attach grad only to floating types
  349. variable.requires_grad = True
  350. return variable
  351. def to_numpy(self, x):
  352. return x.detach().cpu().numpy()
  353. def arange(self, start, stop):
  354. return self.flow.arange(start, stop, dtype=self.flow.int64)
  355. def reduce(self, x, operation, reduced_axes):
  356. for axis in sorted(reduced_axes, reverse=True):
  357. if operation == "min":
  358. x, _ = x.min(dim=axis)
  359. elif operation == "max":
  360. x, _ = x.max(dim=axis)
  361. elif operation in ["sum", "mean", "prod", "any", "all"]:
  362. x = getattr(x, operation)(dim=axis)
  363. else:
  364. raise NotImplementedError("Unknown reduction ", operation)
  365. return x
  366. def transpose(self, x, axes):
  367. return x.permute(axes)
  368. def stack_on_zeroth_dimension(self, tensors: list):
  369. return self.flow.stack(tensors)
  370. def add_axes(self, x, n_axes, pos2len):
  371. repeats = [-1] * n_axes
  372. for axis_position, axis_length in pos2len.items():
  373. x = self.add_axis(x, axis_position)
  374. repeats[axis_position] = axis_length
  375. return x.expand(*repeats)
  376. def tile(self, x, repeats):
  377. return x.repeat(repeats)
  378. def concat(self, tensors, axis: int):
  379. return self.flow.concat(tensors, dim=axis)
  380. def add_axis(self, x, new_position):
  381. return self.flow.unsqueeze(x, new_position)
  382. def is_float_type(self, x):
  383. return x.dtype in [self.flow.float16, self.flow.float32, self.flow.float64]
  384. def layers(self):
  385. from .layers import oneflow
  386. return oneflow
  387. def einsum(self, pattern, *x):
  388. return self.flow.einsum(pattern, *x)
  389. class PaddleBackend(AbstractBackend):
  390. framework_name = "paddle"
  391. def __init__(self):
  392. import paddle
  393. self.paddle = paddle
  394. def is_appropriate_type(self, tensor):
  395. return self.paddle.is_tensor(tensor)
  396. def from_numpy(self, x):
  397. tensor = self.paddle.to_tensor(x)
  398. tensor.stop_gradient = False
  399. return tensor
  400. def to_numpy(self, x):
  401. return x.detach().numpy()
  402. def arange(self, start, stop):
  403. return self.paddle.arange(start, stop, dtype=self.paddle.int64)
  404. def reduce(self, x, operation, axes):
  405. if len(axes) == x.ndim:
  406. # currently paddle returns 1d tensor instead of 0d
  407. return super().reduce(x, operation, axes).squeeze(0)
  408. else:
  409. return super().reduce(x, operation, axes)
  410. def transpose(self, x, axes):
  411. return x.transpose(axes)
  412. def add_axes(self, x, n_axes, pos2len):
  413. repeats = [-1] * n_axes
  414. for axis_position, axis_length in pos2len.items():
  415. x = self.add_axis(x, axis_position)
  416. repeats[axis_position] = axis_length
  417. return x.expand(repeats)
  418. def stack_on_zeroth_dimension(self, tensors: list):
  419. return self.paddle.stack(tensors)
  420. def reshape(self, x, shape):
  421. return x.reshape(shape)
  422. def tile(self, x, repeats):
  423. return x.tile(repeats)
  424. def concat(self, tensors, axis: int):
  425. return self.paddle.concat(tensors, axis=axis)
  426. def add_axis(self, x, new_position):
  427. return x.unsqueeze(new_position)
  428. def is_float_type(self, x):
  429. return x.dtype in [self.paddle.float16, self.paddle.float32, self.paddle.float64]
  430. def layers(self):
  431. from .layers import paddle
  432. return paddle
  433. def einsum(self, pattern, *x):
  434. return self.paddle.einsum(pattern, *x)
  435. def shape(self, x):
  436. return tuple(x.shape)
  437. class TinygradBackend(AbstractBackend):
  438. framework_name = "tinygrad"
  439. def __init__(self):
  440. import tinygrad
  441. self.tinygrad = tinygrad
  442. def is_appropriate_type(self, tensor):
  443. return isinstance(tensor, self.tinygrad.Tensor)
  444. def from_numpy(self, x):
  445. return self.tinygrad.Tensor(x)
  446. def to_numpy(self, x):
  447. return x.numpy()
  448. def arange(self, start, stop):
  449. return self.tinygrad.Tensor.arange(start, stop)
  450. def shape(self, x):
  451. return x.shape
  452. def reshape(self, x, shape):
  453. return x.reshape(shape)
  454. def transpose(self, x, axes):
  455. return x.permute(axes)
  456. def reduce(self, x, operation, axes):
  457. for axis in sorted(axes, reverse=True):
  458. x = getattr(x, operation)(axis=axis)
  459. return x
  460. def stack_on_zeroth_dimension(self, tensors: list):
  461. return self.tinygrad.Tensor.stack(tensors)
  462. def add_axis(self, x, new_position):
  463. return x.unsqueeze(new_position)
  464. def tile(self, x, repeats):
  465. return x.repeat(repeats)
  466. def concat(self, tensors, axis: int):
  467. return tensors[0].cat(*tensors[1:], dim=axis) if len(tensors) > 1 else tensors[0]
  468. def is_float_type(self, x):
  469. return self.tinygrad.dtypes.is_float(x.dtype)
  470. def einsum(self, pattern, *x):
  471. return self.tinygrad.Tensor.einsum(pattern, *x)
  472. class PyTensorBackend(AbstractBackend):
  473. framework_name = "pytensor"
  474. def __init__(self):
  475. from pytensor import tensor
  476. self.pt = tensor
  477. def is_appropriate_type(self, tensor):
  478. return isinstance(tensor, self.pt.TensorVariable)
  479. def is_float_type(self, x):
  480. return x.dtype in self.pt.type.float_dtypes
  481. def from_numpy(self, x):
  482. return self.pt.as_tensor(x)
  483. def to_numpy(self, x):
  484. return x.eval() # Will only work if there are no symbolic inputs
  485. def create_symbol(self, shape):
  486. if not isinstance(shape, tuple | list):
  487. shape = (shape,)
  488. return self.pt.tensor(shape=shape)
  489. def eval_symbol(self, symbol, symbol_value_pairs):
  490. return symbol.eval(dict(symbol_value_pairs))
  491. def arange(self, start, stop):
  492. return self.pt.arange(start, stop)
  493. def shape(self, x):
  494. # use the static shape dimensions where known
  495. return tuple(
  496. static_dim if static_dim is not None else symbolic_dim
  497. for static_dim, symbolic_dim in zip(x.type.shape, x.shape)
  498. )
  499. def stack_on_zeroth_dimension(self, tensors: list):
  500. return self.pt.stack(tensors)
  501. def tile(self, x, repeats):
  502. return self.pt.tile(x, repeats)
  503. def concat(self, tensors, axis: int):
  504. return self.pt.concatenate(tensors, axis=axis)
  505. def add_axis(self, x, new_position):
  506. return self.pt.expand_dims(x, new_position)
  507. def einsum(self, pattern, *x):
  508. return self.pt.einsum(pattern, *x)
  509. class MLXBackend(AbstractBackend):
  510. framework_name = "mlx"
  511. def __init__(self):
  512. import mlx.core as mx
  513. import numpy as np
  514. self.mx = mx
  515. self.np = np
  516. def is_appropriate_type(self, tensor):
  517. return isinstance(tensor, self.mx.array)
  518. def from_numpy(self, x):
  519. return self.mx.array(x)
  520. def to_numpy(self, x):
  521. if x.dtype == self.mx.bfloat16:
  522. x = x.astype(self.mx.float32)
  523. return self.np.array(x)
  524. def arange(self, start, stop):
  525. return self.mx.arange(start, stop)
  526. def stack_on_zeroth_dimension(self, tensors: list):
  527. return self.mx.stack(tensors)
  528. def add_axes(self, x, new_position):
  529. return self.mx.expand_dims(x, new_position)
  530. def tile(self, x, repeats):
  531. return self.mx.tile(x, repeats)
  532. def concat(self, tensors, axis: int):
  533. return self.mx.concatenate(tensors, axis=axis)
  534. def is_float_type(self, x):
  535. return self.mx.issubdtype(x.dtype, self.mx.floating)
  536. def einsum(self, pattern, *x):
  537. return self.mx.einsum(pattern, *x)