einsumfunc.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499
  1. """
  2. Implementation of optimized einsum.
  3. """
  4. import itertools
  5. import operator
  6. from numpy._core.multiarray import c_einsum
  7. from numpy._core.numeric import asanyarray, tensordot
  8. from numpy._core.overrides import array_function_dispatch
# Public names exported by this module.
__all__ = ['einsum', 'einsum_path']

# importing string for string.ascii_letters would be too slow
# the first import before caching has been measured to take 800 µs (#23777)
# The 52 characters allowed as einsum subscripts (a-z then A-Z). In the
# sublist calling form, integer subscript k is translated to
# einsum_symbols[k].
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# Set form of the same symbols; used to find the letters still free for
# ellipsis expansion (set difference against the letters already used).
einsum_symbols_set = set(einsum_symbols)
  14. def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
  15. """
  16. Computes the number of FLOPS in the contraction.
  17. Parameters
  18. ----------
  19. idx_contraction : iterable
  20. The indices involved in the contraction
  21. inner : bool
  22. Does this contraction require an inner product?
  23. num_terms : int
  24. The number of terms in a contraction
  25. size_dictionary : dict
  26. The size of each of the indices in idx_contraction
  27. Returns
  28. -------
  29. flop_count : int
  30. The total number of FLOPS required for the contraction.
  31. Examples
  32. --------
  33. >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
  34. 30
  35. >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
  36. 60
  37. """
  38. overall_size = _compute_size_by_dict(idx_contraction, size_dictionary)
  39. op_factor = max(1, num_terms - 1)
  40. if inner:
  41. op_factor += 1
  42. return overall_size * op_factor
  43. def _compute_size_by_dict(indices, idx_dict):
  44. """
  45. Computes the product of the elements in indices based on the dictionary
  46. idx_dict.
  47. Parameters
  48. ----------
  49. indices : iterable
  50. Indices to base the product on.
  51. idx_dict : dictionary
  52. Dictionary of index sizes
  53. Returns
  54. -------
  55. ret : int
  56. The resulting product.
  57. Examples
  58. --------
  59. >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
  60. 90
  61. """
  62. ret = 1
  63. for i in indices:
  64. ret *= idx_dict[i]
  65. return ret
  66. def _find_contraction(positions, input_sets, output_set):
  67. """
  68. Finds the contraction for a given set of input and output sets.
  69. Parameters
  70. ----------
  71. positions : iterable
  72. Integer positions of terms used in the contraction.
  73. input_sets : list
  74. List of sets that represent the lhs side of the einsum subscript
  75. output_set : set
  76. Set that represents the rhs side of the overall einsum subscript
  77. Returns
  78. -------
  79. new_result : set
  80. The indices of the resulting contraction
  81. remaining : list
  82. List of sets that have not been contracted, the new set is appended to
  83. the end of this list
  84. idx_removed : set
  85. Indices removed from the entire contraction
  86. idx_contraction : set
  87. The indices used in the current contraction
  88. Examples
  89. --------
  90. # A simple dot product test case
  91. >>> pos = (0, 1)
  92. >>> isets = [set('ab'), set('bc')]
  93. >>> oset = set('ac')
  94. >>> _find_contraction(pos, isets, oset)
  95. ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
  96. # A more complex case with additional terms in the contraction
  97. >>> pos = (0, 2)
  98. >>> isets = [set('abd'), set('ac'), set('bdc')]
  99. >>> oset = set('ac')
  100. >>> _find_contraction(pos, isets, oset)
  101. ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
  102. """
  103. idx_contract = set()
  104. idx_remain = output_set.copy()
  105. remaining = []
  106. for ind, value in enumerate(input_sets):
  107. if ind in positions:
  108. idx_contract |= value
  109. else:
  110. remaining.append(value)
  111. idx_remain |= value
  112. new_result = idx_remain & idx_contract
  113. idx_removed = (idx_contract - new_result)
  114. remaining.append(new_result)
  115. return (new_result, remaining, idx_removed, idx_contract)
  116. def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
  117. """
  118. Computes all possible pair contractions, sieves the results based
  119. on ``memory_limit`` and returns the lowest cost path. This algorithm
  120. scales factorial with respect to the elements in the list ``input_sets``.
  121. Parameters
  122. ----------
  123. input_sets : list
  124. List of sets that represent the lhs side of the einsum subscript
  125. output_set : set
  126. Set that represents the rhs side of the overall einsum subscript
  127. idx_dict : dictionary
  128. Dictionary of index sizes
  129. memory_limit : int
  130. The maximum number of elements in a temporary array
  131. Returns
  132. -------
  133. path : list
  134. The optimal contraction order within the memory limit constraint.
  135. Examples
  136. --------
  137. >>> isets = [set('abd'), set('ac'), set('bdc')]
  138. >>> oset = set()
  139. >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
  140. >>> _optimal_path(isets, oset, idx_sizes, 5000)
  141. [(0, 2), (0, 1)]
  142. """
  143. full_results = [(0, [], input_sets)]
  144. for iteration in range(len(input_sets) - 1):
  145. iter_results = []
  146. # Compute all unique pairs
  147. for curr in full_results:
  148. cost, positions, remaining = curr
  149. for con in itertools.combinations(
  150. range(len(input_sets) - iteration), 2
  151. ):
  152. # Find the contraction
  153. cont = _find_contraction(con, remaining, output_set)
  154. new_result, new_input_sets, idx_removed, idx_contract = cont
  155. # Sieve the results based on memory_limit
  156. new_size = _compute_size_by_dict(new_result, idx_dict)
  157. if new_size > memory_limit:
  158. continue
  159. # Build (total_cost, positions, indices_remaining)
  160. total_cost = cost + _flop_count(
  161. idx_contract, idx_removed, len(con), idx_dict
  162. )
  163. new_pos = positions + [con]
  164. iter_results.append((total_cost, new_pos, new_input_sets))
  165. # Update combinatorial list, if we did not find anything return best
  166. # path + remaining contractions
  167. if iter_results:
  168. full_results = iter_results
  169. else:
  170. path = min(full_results, key=lambda x: x[0])[1]
  171. path += [tuple(range(len(input_sets) - iteration))]
  172. return path
  173. # If we have not found anything return single einsum contraction
  174. if len(full_results) == 0:
  175. return [tuple(range(len(input_sets)))]
  176. path = min(full_results, key=lambda x: x[0])[1]
  177. return path
  178. def _parse_possible_contraction(
  179. positions, input_sets, output_set, idx_dict,
  180. memory_limit, path_cost, naive_cost
  181. ):
  182. """Compute the cost (removed size + flops) and resultant indices for
  183. performing the contraction specified by ``positions``.
  184. Parameters
  185. ----------
  186. positions : tuple of int
  187. The locations of the proposed tensors to contract.
  188. input_sets : list of sets
  189. The indices found on each tensors.
  190. output_set : set
  191. The output indices of the expression.
  192. idx_dict : dict
  193. Mapping of each index to its size.
  194. memory_limit : int
  195. The total allowed size for an intermediary tensor.
  196. path_cost : int
  197. The contraction cost so far.
  198. naive_cost : int
  199. The cost of the unoptimized expression.
  200. Returns
  201. -------
  202. cost : (int, int)
  203. A tuple containing the size of any indices removed, and the flop cost.
  204. positions : tuple of int
  205. The locations of the proposed tensors to contract.
  206. new_input_sets : list of sets
  207. The resulting new list of indices if this proposed contraction
  208. is performed.
  209. """
  210. # Find the contraction
  211. contract = _find_contraction(positions, input_sets, output_set)
  212. idx_result, new_input_sets, idx_removed, idx_contract = contract
  213. # Sieve the results based on memory_limit
  214. new_size = _compute_size_by_dict(idx_result, idx_dict)
  215. if new_size > memory_limit:
  216. return None
  217. # Build sort tuple
  218. old_sizes = (
  219. _compute_size_by_dict(input_sets[p], idx_dict) for p in positions
  220. )
  221. removed_size = sum(old_sizes) - new_size
  222. # NB: removed_size used to be just the size of any removed indices i.e.:
  223. # helpers.compute_size_by_dict(idx_removed, idx_dict)
  224. cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
  225. sort = (-removed_size, cost)
  226. # Sieve based on total cost as well
  227. if (path_cost + cost) > naive_cost:
  228. return None
  229. # Add contraction to possible choices
  230. return [sort, positions, new_input_sets]
  231. def _update_other_results(results, best):
  232. """Update the positions and provisional input_sets of ``results``
  233. based on performing the contraction result ``best``. Remove any
  234. involving the tensors contracted.
  235. Parameters
  236. ----------
  237. results : list
  238. List of contraction results produced by
  239. ``_parse_possible_contraction``.
  240. best : list
  241. The best contraction of ``results`` i.e. the one that
  242. will be performed.
  243. Returns
  244. -------
  245. mod_results : list
  246. The list of modified results, updated with outcome of
  247. ``best`` contraction.
  248. """
  249. best_con = best[1]
  250. bx, by = best_con
  251. mod_results = []
  252. for cost, (x, y), con_sets in results:
  253. # Ignore results involving tensors just contracted
  254. if x in best_con or y in best_con:
  255. continue
  256. # Update the input_sets
  257. del con_sets[by - int(by > x) - int(by > y)]
  258. del con_sets[bx - int(bx > x) - int(bx > y)]
  259. con_sets.insert(-1, best[2][-1])
  260. # Update the position indices
  261. mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
  262. mod_results.append((cost, mod_con, con_sets))
  263. return mod_results
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-removed_size, cost)`` built by ``_parse_possible_contraction``.
    There, ``removed_size`` is the summed size of the two input tensors
    minus the size of their result, and ``cost`` is the flop count of the
    pair contraction. What this amounts to is prioritizing matrix
    multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    A candidate pair is discarded if its intermediate exceeds
    ``memory_limit``, or if the running cost would exceed the cost of a
    single naive contraction of all terms. If no candidate remains, the
    remaining terms are contracted together in one final step.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost: the flop count of contracting every term at
    # once. It is used as an upper bound on acceptable path costs.
    contract = _find_contraction(
        range(len(input_sets)), input_sets, output_set
    )
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(
        idx_contract, idx_removed, len(input_sets), idx_dict
    )

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    # Candidates as returned by _parse_possible_contraction:
    # [sort_key, positions, new_input_sets]
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on the first step, only previously
        # found pairs on subsequent steps
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(
                positions, input_sets, output_set, idx_dict,
                memory_limit, path_cost, naive_cost
            )
            if result is not None:
                known_contractions.append(result)

        # If we do not have a inner contraction, rescan pairs
        # including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(
                range(len(input_sets)), 2
            ):
                result = _parse_possible_contraction(
                    positions, input_sets, output_set, idx_dict,
                    memory_limit, path_cost, naive_cost
                )
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions,
            # default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Sort based on first index (the (-removed_size, cost) key)
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible
        # to the next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost (best[0][1] is the flop cost)
        path.append(best[1])
        path_cost += best[0][1]

    return path
  358. def _can_dot(inputs, result, idx_removed):
  359. """
  360. Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.
  361. Parameters
  362. ----------
  363. inputs : list of str
  364. Specifies the subscripts for summation.
  365. result : str
  366. Resulting summation.
  367. idx_removed : set
  368. Indices that are removed in the summation
  369. Returns
  370. -------
  371. type : bool
  372. Returns true if BLAS should and can be used, else False
  373. Notes
  374. -----
  375. If the operations is BLAS level 1 or 2 and is not already aligned
  376. we default back to einsum as the memory movement to copy is more
  377. costly than the operation itself.
  378. Examples
  379. --------
  380. # Standard GEMM operation
  381. >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
  382. True
  383. # Can use the standard BLAS, but requires odd data movement
  384. >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
  385. False
  386. # DDOT where the memory is not aligned
  387. >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
  388. False
  389. """
  390. # All `dot` calls remove indices
  391. if len(idx_removed) == 0:
  392. return False
  393. # BLAS can only handle two operands
  394. if len(inputs) != 2:
  395. return False
  396. input_left, input_right = inputs
  397. for c in set(input_left + input_right):
  398. # can't deal with repeated indices on same input or more than 2 total
  399. nl, nr = input_left.count(c), input_right.count(c)
  400. if (nl > 1) or (nr > 1) or (nl + nr > 2):
  401. return False
  402. # can't do implicit summation or dimension collapse e.g.
  403. # "ab,bc->c" (implicitly sum over 'a')
  404. # "ab,ca->ca" (take diagonal of 'a')
  405. if nl + nr - 1 == int(c in result):
  406. return False
  407. # Build a few temporaries
  408. set_left = set(input_left)
  409. set_right = set(input_right)
  410. keep_left = set_left - idx_removed
  411. keep_right = set_right - idx_removed
  412. rs = len(idx_removed)
  413. # At this point we are a DOT, GEMV, or GEMM operation
  414. # Handle inner products
  415. # DDOT with aligned data
  416. if input_left == input_right:
  417. return True
  418. # DDOT without aligned data (better to use einsum)
  419. if set_left == set_right:
  420. return False
  421. # Handle the 4 possible (aligned) GEMV or GEMM cases
  422. # GEMM or GEMV no transpose
  423. if input_left[-rs:] == input_right[:rs]:
  424. return True
  425. # GEMM or GEMV transpose both
  426. if input_left[:rs] == input_right[-rs:]:
  427. return True
  428. # GEMM or GEMV transpose right
  429. if input_left[-rs:] == input_right[-rs:]:
  430. return True
  431. # GEMM or GEMV transpose left
  432. if input_left[:rs] == input_right[:rs]:
  433. return True
  434. # Einsum is faster than GEMV if we have to copy data
  435. if not keep_left or not keep_right:
  436. return False
  437. # We are a matrix-matrix product, but we need to copy data
  438. return True
  439. def _parse_einsum_input(operands):
  440. """
  441. A reproduction of einsum c side einsum parsing in python.
  442. Returns
  443. -------
  444. input_strings : str
  445. Parsed input strings
  446. output_string : str
  447. Parsed output string
  448. operands : list of array_like
  449. The operands to use in the numpy contraction
  450. Examples
  451. --------
  452. The operand list is simplified to reduce printing:
  453. >>> np.random.seed(123)
  454. >>> a = np.random.rand(4, 4)
  455. >>> b = np.random.rand(4, 4, 4)
  456. >>> _parse_einsum_input(('...a,...a->...', a, b))
  457. ('za,xza', 'xz', [a, b]) # may vary
  458. >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
  459. ('za,xza', 'xz', [a, b]) # may vary
  460. """
  461. if len(operands) == 0:
  462. raise ValueError("No input operands")
  463. if isinstance(operands[0], str):
  464. subscripts = operands[0].replace(" ", "")
  465. operands = [asanyarray(v) for v in operands[1:]]
  466. # Ensure all characters are valid
  467. for s in subscripts:
  468. if s in '.,->':
  469. continue
  470. if s not in einsum_symbols:
  471. raise ValueError("Character %s is not a valid symbol." % s)
  472. else:
  473. tmp_operands = list(operands)
  474. operand_list = []
  475. subscript_list = []
  476. for p in range(len(operands) // 2):
  477. operand_list.append(tmp_operands.pop(0))
  478. subscript_list.append(tmp_operands.pop(0))
  479. output_list = tmp_operands[-1] if len(tmp_operands) else None
  480. operands = [asanyarray(v) for v in operand_list]
  481. subscripts = ""
  482. last = len(subscript_list) - 1
  483. for num, sub in enumerate(subscript_list):
  484. for s in sub:
  485. if s is Ellipsis:
  486. subscripts += "..."
  487. else:
  488. try:
  489. s = operator.index(s)
  490. except TypeError as e:
  491. raise TypeError(
  492. "For this input type lists must contain "
  493. "either int or Ellipsis"
  494. ) from e
  495. subscripts += einsum_symbols[s]
  496. if num != last:
  497. subscripts += ","
  498. if output_list is not None:
  499. subscripts += "->"
  500. for s in output_list:
  501. if s is Ellipsis:
  502. subscripts += "..."
  503. else:
  504. try:
  505. s = operator.index(s)
  506. except TypeError as e:
  507. raise TypeError(
  508. "For this input type lists must contain "
  509. "either int or Ellipsis"
  510. ) from e
  511. subscripts += einsum_symbols[s]
  512. # Check for proper "->"
  513. if ("-" in subscripts) or (">" in subscripts):
  514. invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
  515. if invalid or (subscripts.count("->") != 1):
  516. raise ValueError("Subscripts can only contain one '->'.")
  517. # Parse ellipses
  518. if "." in subscripts:
  519. used = subscripts.replace(".", "").replace(",", "").replace("->", "")
  520. unused = list(einsum_symbols_set - set(used))
  521. ellipse_inds = "".join(unused)
  522. longest = 0
  523. if "->" in subscripts:
  524. input_tmp, output_sub = subscripts.split("->")
  525. split_subscripts = input_tmp.split(",")
  526. out_sub = True
  527. else:
  528. split_subscripts = subscripts.split(',')
  529. out_sub = False
  530. for num, sub in enumerate(split_subscripts):
  531. if "." in sub:
  532. if (sub.count(".") != 3) or (sub.count("...") != 1):
  533. raise ValueError("Invalid Ellipses.")
  534. # Take into account numerical values
  535. if operands[num].shape == ():
  536. ellipse_count = 0
  537. else:
  538. ellipse_count = max(operands[num].ndim, 1)
  539. ellipse_count -= (len(sub) - 3)
  540. if ellipse_count > longest:
  541. longest = ellipse_count
  542. if ellipse_count < 0:
  543. raise ValueError("Ellipses lengths do not match.")
  544. elif ellipse_count == 0:
  545. split_subscripts[num] = sub.replace('...', '')
  546. else:
  547. rep_inds = ellipse_inds[-ellipse_count:]
  548. split_subscripts[num] = sub.replace('...', rep_inds)
  549. subscripts = ",".join(split_subscripts)
  550. if longest == 0:
  551. out_ellipse = ""
  552. else:
  553. out_ellipse = ellipse_inds[-longest:]
  554. if out_sub:
  555. subscripts += "->" + output_sub.replace("...", out_ellipse)
  556. else:
  557. # Special care for outputless ellipses
  558. output_subscript = ""
  559. tmp_subscripts = subscripts.replace(",", "")
  560. for s in sorted(set(tmp_subscripts)):
  561. if s not in (einsum_symbols):
  562. raise ValueError("Character %s is not a valid symbol." % s)
  563. if tmp_subscripts.count(s) == 1:
  564. output_subscript += s
  565. normal_inds = ''.join(sorted(set(output_subscript) -
  566. set(out_ellipse)))
  567. subscripts += "->" + out_ellipse + normal_inds
  568. # Build output string if does not exist
  569. if "->" in subscripts:
  570. input_subscripts, output_subscript = subscripts.split("->")
  571. else:
  572. input_subscripts = subscripts
  573. # Build output subscripts
  574. tmp_subscripts = subscripts.replace(",", "")
  575. output_subscript = ""
  576. for s in sorted(set(tmp_subscripts)):
  577. if s not in einsum_symbols:
  578. raise ValueError("Character %s is not a valid symbol." % s)
  579. if tmp_subscripts.count(s) == 1:
  580. output_subscript += s
  581. # Make sure output subscripts are in the input
  582. for char in output_subscript:
  583. if output_subscript.count(char) != 1:
  584. raise ValueError("Output character %s appeared more than once in "
  585. "the output." % char)
  586. if char not in input_subscripts:
  587. raise ValueError("Output character %s did not appear in the input"
  588. % char)
  589. # Make sure number operands is equivalent to the number of terms
  590. if len(input_subscripts.split(',')) != len(operands):
  591. raise ValueError("Number of einsum subscripts must be equal to the "
  592. "number of operands.")
  593. return (input_subscripts, output_subscript, operands)
  594. def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
  595. # NOTE: technically, we should only dispatch on array-like arguments, not
  596. # subscripts (given as strings). But separating operands into
  597. # arrays/subscripts is a little tricky/slow (given einsum's two supported
  598. # signatures), so as a practical shortcut we dispatch on everything.
  599. # Strings will be ignored for dispatching since they don't define
  600. # __array_function__.
  601. return operands
  602. @array_function_dispatch(_einsum_path_dispatcher, module='numpy')
  603. def einsum_path(*operands, optimize='greedy', einsum_call=False):
  604. """
  605. einsum_path(subscripts, *operands, optimize='greedy')
  606. Evaluates the lowest cost contraction order for an einsum expression by
  607. considering the creation of intermediate arrays.
  608. Parameters
  609. ----------
  610. subscripts : str
  611. Specifies the subscripts for summation.
  612. *operands : list of array_like
  613. These are the arrays for the operation.
  614. optimize : {bool, list, tuple, 'greedy', 'optimal'}
  615. Choose the type of path. If a tuple is provided, the second argument is
  616. assumed to be the maximum intermediate size created. If only a single
  617. argument is provided the largest input or output array size is used
  618. as a maximum intermediate size.
  619. * if a list is given that starts with ``einsum_path``, uses this as the
  620. contraction path
  621. * if False no optimization is taken
  622. * if True defaults to the 'greedy' algorithm
  623. * 'optimal' An algorithm that combinatorially explores all possible
  624. ways of contracting the listed tensors and chooses the least costly
  625. path. Scales exponentially with the number of terms in the
  626. contraction.
  627. * 'greedy' An algorithm that chooses the best pair contraction
  628. at each step. Effectively, this algorithm searches the largest inner,
  629. Hadamard, and then outer products at each step. Scales cubically with
  630. the number of terms in the contraction. Equivalent to the 'optimal'
  631. path for most contractions.
  632. Default is 'greedy'.
  633. Returns
  634. -------
  635. path : list of tuples
  636. A list representation of the einsum path.
  637. string_repr : str
  638. A printable representation of the einsum path.
  639. Notes
  640. -----
  641. The resulting path indicates which terms of the input contraction should be
  642. contracted first, the result of this contraction is then appended to the
  643. end of the contraction list. This list can then be iterated over until all
  644. intermediate contractions are complete.
  645. See Also
  646. --------
  647. einsum, linalg.multi_dot
  648. Examples
  649. --------
  650. We can begin with a chain dot example. In this case, it is optimal to
  651. contract the ``b`` and ``c`` tensors first as represented by the first
  652. element of the path ``(1, 2)``. The resulting tensor is added to the end
  653. of the contraction and the remaining contraction ``(0, 1)`` is then
  654. completed.
  655. >>> np.random.seed(123)
  656. >>> a = np.random.rand(2, 2)
  657. >>> b = np.random.rand(2, 5)
  658. >>> c = np.random.rand(5, 2)
  659. >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
  660. >>> print(path_info[0])
  661. ['einsum_path', (1, 2), (0, 1)]
  662. >>> print(path_info[1])
  663. Complete contraction: ij,jk,kl->il # may vary
  664. Naive scaling: 4
  665. Optimized scaling: 3
  666. Naive FLOP count: 1.600e+02
  667. Optimized FLOP count: 5.600e+01
  668. Theoretical speedup: 2.857
  669. Largest intermediate: 4.000e+00 elements
  670. -------------------------------------------------------------------------
  671. scaling current remaining
  672. -------------------------------------------------------------------------
  673. 3 kl,jk->jl ij,jl->il
  674. 3 jl,ij->il il->il
  675. A more complex index transformation example.
  676. >>> I = np.random.rand(10, 10, 10, 10)
  677. >>> C = np.random.rand(10, 10)
  678. >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
  679. ... optimize='greedy')
  680. >>> print(path_info[0])
  681. ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
  682. >>> print(path_info[1])
  683. Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
  684. Naive scaling: 8
  685. Optimized scaling: 5
  686. Naive FLOP count: 8.000e+08
  687. Optimized FLOP count: 8.000e+05
  688. Theoretical speedup: 1000.000
  689. Largest intermediate: 1.000e+04 elements
  690. --------------------------------------------------------------------------
  691. scaling current remaining
  692. --------------------------------------------------------------------------
  693. 5 abcd,ea->bcde fb,gc,hd,bcde->efgh
  694. 5 bcde,fb->cdef gc,hd,cdef->efgh
  695. 5 cdef,gc->defg hd,defg->efgh
  696. 5 defg,hd->efgh efgh->efgh
  697. """
  698. # Figure out what the path really is
  699. path_type = optimize
  700. if path_type is True:
  701. path_type = 'greedy'
  702. if path_type is None:
  703. path_type = False
  704. explicit_einsum_path = False
  705. memory_limit = None
  706. # No optimization or a named path algorithm
  707. if (path_type is False) or isinstance(path_type, str):
  708. pass
  709. # Given an explicit path
  710. elif len(path_type) and (path_type[0] == 'einsum_path'):
  711. explicit_einsum_path = True
  712. # Path tuple with memory limit
  713. elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
  714. isinstance(path_type[1], (int, float))):
  715. memory_limit = int(path_type[1])
  716. path_type = path_type[0]
  717. else:
  718. raise TypeError("Did not understand the path: %s" % str(path_type))
  719. # Hidden option, only einsum should call this
  720. einsum_call_arg = einsum_call
  721. # Python side parsing
  722. input_subscripts, output_subscript, operands = (
  723. _parse_einsum_input(operands)
  724. )
  725. # Build a few useful list and sets
  726. input_list = input_subscripts.split(',')
  727. input_sets = [set(x) for x in input_list]
  728. output_set = set(output_subscript)
  729. indices = set(input_subscripts.replace(',', ''))
  730. # Get length of each unique dimension and ensure all dimensions are correct
  731. dimension_dict = {}
  732. broadcast_indices = [[] for x in range(len(input_list))]
  733. for tnum, term in enumerate(input_list):
  734. sh = operands[tnum].shape
  735. if len(sh) != len(term):
  736. raise ValueError("Einstein sum subscript %s does not contain the "
  737. "correct number of indices for operand %d."
  738. % (input_subscripts[tnum], tnum))
  739. for cnum, char in enumerate(term):
  740. dim = sh[cnum]
  741. # Build out broadcast indices
  742. if dim == 1:
  743. broadcast_indices[tnum].append(char)
  744. if char in dimension_dict.keys():
  745. # For broadcasting cases we always want the largest dim size
  746. if dimension_dict[char] == 1:
  747. dimension_dict[char] = dim
  748. elif dim not in (1, dimension_dict[char]):
  749. raise ValueError("Size of label '%s' for operand %d (%d) "
  750. "does not match previous terms (%d)."
  751. % (char, tnum, dimension_dict[char], dim))
  752. else:
  753. dimension_dict[char] = dim
  754. # Convert broadcast inds to sets
  755. broadcast_indices = [set(x) for x in broadcast_indices]
  756. # Compute size of each input array plus the output array
  757. size_list = [_compute_size_by_dict(term, dimension_dict)
  758. for term in input_list + [output_subscript]]
  759. max_size = max(size_list)
  760. if memory_limit is None:
  761. memory_arg = max_size
  762. else:
  763. memory_arg = memory_limit
  764. # Compute naive cost
  765. # This isn't quite right, need to look into exactly how einsum does this
  766. inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
  767. naive_cost = _flop_count(
  768. indices, inner_product, len(input_list), dimension_dict
  769. )
  770. # Compute the path
  771. if explicit_einsum_path:
  772. path = path_type[1:]
  773. elif (
  774. (path_type is False)
  775. or (len(input_list) in [1, 2])
  776. or (indices == output_set)
  777. ):
  778. # Nothing to be optimized, leave it to einsum
  779. path = [tuple(range(len(input_list)))]
  780. elif path_type == "greedy":
  781. path = _greedy_path(
  782. input_sets, output_set, dimension_dict, memory_arg
  783. )
  784. elif path_type == "optimal":
  785. path = _optimal_path(
  786. input_sets, output_set, dimension_dict, memory_arg
  787. )
  788. else:
  789. raise KeyError("Path name %s not found", path_type)
  790. cost_list, scale_list, size_list, contraction_list = [], [], [], []
  791. # Build contraction tuple (positions, gemm, einsum_str, remaining)
  792. for cnum, contract_inds in enumerate(path):
  793. # Make sure we remove inds from right to left
  794. contract_inds = tuple(sorted(contract_inds, reverse=True))
  795. contract = _find_contraction(contract_inds, input_sets, output_set)
  796. out_inds, input_sets, idx_removed, idx_contract = contract
  797. cost = _flop_count(
  798. idx_contract, idx_removed, len(contract_inds), dimension_dict
  799. )
  800. cost_list.append(cost)
  801. scale_list.append(len(idx_contract))
  802. size_list.append(_compute_size_by_dict(out_inds, dimension_dict))
  803. bcast = set()
  804. tmp_inputs = []
  805. for x in contract_inds:
  806. tmp_inputs.append(input_list.pop(x))
  807. bcast |= broadcast_indices.pop(x)
  808. new_bcast_inds = bcast - idx_removed
  809. # If we're broadcasting, nix blas
  810. if not len(idx_removed & bcast):
  811. do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
  812. else:
  813. do_blas = False
  814. # Last contraction
  815. if (cnum - len(path)) == -1:
  816. idx_result = output_subscript
  817. else:
  818. sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
  819. idx_result = "".join([x[1] for x in sorted(sort_result)])
  820. input_list.append(idx_result)
  821. broadcast_indices.append(new_bcast_inds)
  822. einsum_str = ",".join(tmp_inputs) + "->" + idx_result
  823. contraction = (
  824. contract_inds, idx_removed, einsum_str, input_list[:], do_blas
  825. )
  826. contraction_list.append(contraction)
  827. opt_cost = sum(cost_list) + 1
  828. if len(input_list) != 1:
  829. # Explicit "einsum_path" is usually trusted, but we detect this kind of
  830. # mistake in order to prevent from returning an intermediate value.
  831. raise RuntimeError(
  832. "Invalid einsum_path is specified: {} more operands has to be "
  833. "contracted.".format(len(input_list) - 1))
  834. if einsum_call_arg:
  835. return (operands, contraction_list)
  836. # Return the path along with a nice string representation
  837. overall_contraction = input_subscripts + "->" + output_subscript
  838. header = ("scaling", "current", "remaining")
  839. speedup = naive_cost / opt_cost
  840. max_i = max(size_list)
  841. path_print = " Complete contraction: %s\n" % overall_contraction
  842. path_print += " Naive scaling: %d\n" % len(indices)
  843. path_print += " Optimized scaling: %d\n" % max(scale_list)
  844. path_print += " Naive FLOP count: %.3e\n" % naive_cost
  845. path_print += " Optimized FLOP count: %.3e\n" % opt_cost
  846. path_print += " Theoretical speedup: %3.3f\n" % speedup
  847. path_print += " Largest intermediate: %.3e elements\n" % max_i
  848. path_print += "-" * 74 + "\n"
  849. path_print += "%6s %24s %40s\n" % header
  850. path_print += "-" * 74
  851. for n, contraction in enumerate(contraction_list):
  852. inds, idx_rm, einsum_str, remaining, blas = contraction
  853. remaining_str = ",".join(remaining) + "->" + output_subscript
  854. path_run = (scale_list[n], einsum_str, remaining_str)
  855. path_print += "\n%4d %24s %40s" % path_run
  856. path = ['einsum_path'] + path
  857. return (path, path_print)
  858. def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
  859. # Arguably we dispatch on more arguments than we really should; see note in
  860. # _einsum_path_dispatcher for why.
  861. yield from operands
  862. yield out
  863. # Rewrite einsum to handle different cases
  864. @array_function_dispatch(_einsum_dispatcher, module='numpy')
  865. def einsum(*operands, out=None, optimize=False, **kwargs):
  866. """
  867. einsum(subscripts, *operands, out=None, dtype=None, order='K',
  868. casting='safe', optimize=False)
  869. Evaluates the Einstein summation convention on the operands.
  870. Using the Einstein summation convention, many common multi-dimensional,
  871. linear algebraic array operations can be represented in a simple fashion.
  872. In *implicit* mode `einsum` computes these values.
  873. In *explicit* mode, `einsum` provides further flexibility to compute
  874. other array operations that might not be considered classical Einstein
  875. summation operations, by disabling, or forcing summation over specified
  876. subscript labels.
  877. See the notes and examples for clarification.
  878. Parameters
  879. ----------
  880. subscripts : str
  881. Specifies the subscripts for summation as comma separated list of
  882. subscript labels. An implicit (classical Einstein summation)
  883. calculation is performed unless the explicit indicator '->' is
  884. included as well as subscript labels of the precise output form.
  885. operands : list of array_like
  886. These are the arrays for the operation.
  887. out : ndarray, optional
  888. If provided, the calculation is done into this array.
  889. dtype : {data-type, None}, optional
  890. If provided, forces the calculation to use the data type specified.
  891. Note that you may have to also give a more liberal `casting`
  892. parameter to allow the conversions. Default is None.
  893. order : {'C', 'F', 'A', 'K'}, optional
  894. Controls the memory layout of the output. 'C' means it should
  895. be C contiguous. 'F' means it should be Fortran contiguous,
  896. 'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
  897. 'K' means it should be as close to the layout as the inputs as
  898. is possible, including arbitrarily permuted axes.
  899. Default is 'K'.
  900. casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
  901. Controls what kind of data casting may occur. Setting this to
  902. 'unsafe' is not recommended, as it can adversely affect accumulations.
  903. * 'no' means the data types should not be cast at all.
  904. * 'equiv' means only byte-order changes are allowed.
  905. * 'safe' means only casts which can preserve values are allowed.
  906. * 'same_kind' means only safe casts or casts within a kind,
  907. like float64 to float32, are allowed.
  908. * 'unsafe' means any data conversions may be done.
  909. Default is 'safe'.
  910. optimize : {False, True, 'greedy', 'optimal'}, optional
  911. Controls if intermediate optimization should occur. No optimization
  912. will occur if False and True will default to the 'greedy' algorithm.
  913. Also accepts an explicit contraction list from the ``np.einsum_path``
  914. function. See ``np.einsum_path`` for more details. Defaults to False.
  915. Returns
  916. -------
  917. output : ndarray
  918. The calculation based on the Einstein summation convention.
  919. See Also
  920. --------
  921. einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
  922. einsum:
  923. Similar verbose interface is provided by the
  924. `einops <https://github.com/arogozhnikov/einops>`_ package to cover
  925. additional operations: transpose, reshape/flatten, repeat/tile,
  926. squeeze/unsqueeze and reductions.
  927. The `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
  928. optimizes contraction order for einsum-like expressions
  929. in backend-agnostic manner.
  930. Notes
  931. -----
  932. The Einstein summation convention can be used to compute
  933. many multi-dimensional, linear algebraic array operations. `einsum`
  934. provides a succinct way of representing these.
  935. A non-exhaustive list of these operations,
  936. which can be computed by `einsum`, is shown below along with examples:
  937. * Trace of an array, :py:func:`numpy.trace`.
  938. * Return a diagonal, :py:func:`numpy.diag`.
  939. * Array axis summations, :py:func:`numpy.sum`.
  940. * Transpositions and permutations, :py:func:`numpy.transpose`.
  941. * Matrix multiplication and dot product, :py:func:`numpy.matmul`
  942. :py:func:`numpy.dot`.
  943. * Vector inner and outer products, :py:func:`numpy.inner`
  944. :py:func:`numpy.outer`.
  945. * Broadcasting, element-wise and scalar multiplication,
  946. :py:func:`numpy.multiply`.
  947. * Tensor contractions, :py:func:`numpy.tensordot`.
  948. * Chained array operations, in efficient calculation order,
  949. :py:func:`numpy.einsum_path`.
  950. The subscripts string is a comma-separated list of subscript labels,
  951. where each label refers to a dimension of the corresponding operand.
  952. Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
  953. is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
  954. appears only once, it is not summed, so ``np.einsum('i', a)``
  955. produces a view of ``a`` with no changes. A further example
  956. ``np.einsum('ij,jk', a, b)`` describes traditional matrix multiplication
  957. and is equivalent to :py:func:`np.matmul(a,b) <numpy.matmul>`.
  958. Repeated subscript labels in one operand take the diagonal.
  959. For example, ``np.einsum('ii', a)`` is equivalent to
  960. :py:func:`np.trace(a) <numpy.trace>`.
  961. In *implicit mode*, the chosen subscripts are important
  962. since the axes of the output are reordered alphabetically. This
  963. means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
  964. ``np.einsum('ji', a)`` takes its transpose. Additionally,
  965. ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
  966. ``np.einsum('ij,jh', a, b)`` returns the transpose of the
  967. multiplication since subscript 'h' precedes subscript 'i'.
  968. In *explicit mode* the output can be directly controlled by
  969. specifying output subscript labels. This requires the
  970. identifier '->' as well as the list of output subscript labels.
  971. This feature increases the flexibility of the function since
  972. summing can be disabled or forced when required. The call
  973. ``np.einsum('i->', a)`` is like :py:func:`np.sum(a) <numpy.sum>`
  974. if ``a`` is a 1-D array, and ``np.einsum('ii->i', a)``
  975. is like :py:func:`np.diag(a) <numpy.diag>` if ``a`` is a square 2-D array.
  976. The difference is that `einsum` does not allow broadcasting by default.
  977. Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
  978. order of the output subscript labels and therefore returns matrix
  979. multiplication, unlike the example above in implicit mode.
  980. To enable and control broadcasting, use an ellipsis. Default
  981. NumPy-style broadcasting is done by adding an ellipsis
  982. to the left of each term, like ``np.einsum('...ii->...i', a)``.
  983. ``np.einsum('...i->...', a)`` is like
  984. :py:func:`np.sum(a, axis=-1) <numpy.sum>` for array ``a`` of any shape.
  985. To take the trace along the first and last axes,
  986. you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
  987. product with the left-most indices instead of rightmost, one can do
  988. ``np.einsum('ij...,jk...->ik...', a, b)``.
  989. When there is only one operand, no axes are summed, and no output
  990. parameter is provided, a view into the operand is returned instead
  991. of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
  992. produces a view (changed in version 1.10.0).
  993. `einsum` also provides an alternative way to provide the subscripts and
  994. operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
  995. If the output shape is not provided in this format `einsum` will be
  996. calculated in implicit mode, otherwise it will be performed explicitly.
  997. The examples below have corresponding `einsum` calls with the two
  998. parameter methods.
  999. Views returned from einsum are now writeable whenever the input array
  1000. is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
  1001. have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
  1002. and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
  1003. of a 2D array.
  1004. Added the ``optimize`` argument which will optimize the contraction order
  1005. of an einsum expression. For a contraction with three or more operands
  1006. this can greatly increase the computational efficiency at the cost of
  1007. a larger memory footprint during computation.
  1008. Typically a 'greedy' algorithm is applied which empirical tests have shown
  1009. returns the optimal path in the majority of cases. In some cases 'optimal'
  1010. will return the superlative path through a more expensive, exhaustive
  1011. search. For iterative calculations it may be advisable to calculate
  1012. the optimal path once and reuse that path by supplying it as an argument.
  1013. An example is given below.
  1014. See :py:func:`numpy.einsum_path` for more details.
  1015. Examples
  1016. --------
  1017. >>> a = np.arange(25).reshape(5,5)
  1018. >>> b = np.arange(5)
  1019. >>> c = np.arange(6).reshape(2,3)
  1020. Trace of a matrix:
  1021. >>> np.einsum('ii', a)
  1022. 60
  1023. >>> np.einsum(a, [0,0])
  1024. 60
  1025. >>> np.trace(a)
  1026. 60
  1027. Extract the diagonal (requires explicit form):
  1028. >>> np.einsum('ii->i', a)
  1029. array([ 0, 6, 12, 18, 24])
  1030. >>> np.einsum(a, [0,0], [0])
  1031. array([ 0, 6, 12, 18, 24])
  1032. >>> np.diag(a)
  1033. array([ 0, 6, 12, 18, 24])
  1034. Sum over an axis (requires explicit form):
  1035. >>> np.einsum('ij->i', a)
  1036. array([ 10, 35, 60, 85, 110])
  1037. >>> np.einsum(a, [0,1], [0])
  1038. array([ 10, 35, 60, 85, 110])
  1039. >>> np.sum(a, axis=1)
  1040. array([ 10, 35, 60, 85, 110])
  1041. For higher dimensional arrays summing a single axis can be done
  1042. with ellipsis:
  1043. >>> np.einsum('...j->...', a)
  1044. array([ 10, 35, 60, 85, 110])
  1045. >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
  1046. array([ 10, 35, 60, 85, 110])
  1047. Compute a matrix transpose, or reorder any number of axes:
  1048. >>> np.einsum('ji', c)
  1049. array([[0, 3],
  1050. [1, 4],
  1051. [2, 5]])
  1052. >>> np.einsum('ij->ji', c)
  1053. array([[0, 3],
  1054. [1, 4],
  1055. [2, 5]])
  1056. >>> np.einsum(c, [1,0])
  1057. array([[0, 3],
  1058. [1, 4],
  1059. [2, 5]])
  1060. >>> np.transpose(c)
  1061. array([[0, 3],
  1062. [1, 4],
  1063. [2, 5]])
  1064. Vector inner products:
  1065. >>> np.einsum('i,i', b, b)
  1066. 30
  1067. >>> np.einsum(b, [0], b, [0])
  1068. 30
  1069. >>> np.inner(b,b)
  1070. 30
  1071. Matrix vector multiplication:
  1072. >>> np.einsum('ij,j', a, b)
  1073. array([ 30, 80, 130, 180, 230])
  1074. >>> np.einsum(a, [0,1], b, [1])
  1075. array([ 30, 80, 130, 180, 230])
  1076. >>> np.dot(a, b)
  1077. array([ 30, 80, 130, 180, 230])
  1078. >>> np.einsum('...j,j', a, b)
  1079. array([ 30, 80, 130, 180, 230])
  1080. Broadcasting and scalar multiplication:
  1081. >>> np.einsum('..., ...', 3, c)
  1082. array([[ 0, 3, 6],
  1083. [ 9, 12, 15]])
  1084. >>> np.einsum(',ij', 3, c)
  1085. array([[ 0, 3, 6],
  1086. [ 9, 12, 15]])
  1087. >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
  1088. array([[ 0, 3, 6],
  1089. [ 9, 12, 15]])
  1090. >>> np.multiply(3, c)
  1091. array([[ 0, 3, 6],
  1092. [ 9, 12, 15]])
  1093. Vector outer product:
  1094. >>> np.einsum('i,j', np.arange(2)+1, b)
  1095. array([[0, 1, 2, 3, 4],
  1096. [0, 2, 4, 6, 8]])
  1097. >>> np.einsum(np.arange(2)+1, [0], b, [1])
  1098. array([[0, 1, 2, 3, 4],
  1099. [0, 2, 4, 6, 8]])
  1100. >>> np.outer(np.arange(2)+1, b)
  1101. array([[0, 1, 2, 3, 4],
  1102. [0, 2, 4, 6, 8]])
  1103. Tensor contraction:
  1104. >>> a = np.arange(60.).reshape(3,4,5)
  1105. >>> b = np.arange(24.).reshape(4,3,2)
  1106. >>> np.einsum('ijk,jil->kl', a, b)
  1107. array([[4400., 4730.],
  1108. [4532., 4874.],
  1109. [4664., 5018.],
  1110. [4796., 5162.],
  1111. [4928., 5306.]])
  1112. >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
  1113. array([[4400., 4730.],
  1114. [4532., 4874.],
  1115. [4664., 5018.],
  1116. [4796., 5162.],
  1117. [4928., 5306.]])
  1118. >>> np.tensordot(a,b, axes=([1,0],[0,1]))
  1119. array([[4400., 4730.],
  1120. [4532., 4874.],
  1121. [4664., 5018.],
  1122. [4796., 5162.],
  1123. [4928., 5306.]])
  1124. Writeable returned arrays (since version 1.10.0):
  1125. >>> a = np.zeros((3, 3))
  1126. >>> np.einsum('ii->i', a)[:] = 1
  1127. >>> a
  1128. array([[1., 0., 0.],
  1129. [0., 1., 0.],
  1130. [0., 0., 1.]])
  1131. Example of ellipsis use:
  1132. >>> a = np.arange(6).reshape((3,2))
  1133. >>> b = np.arange(12).reshape((4,3))
  1134. >>> np.einsum('ki,jk->ij', a, b)
  1135. array([[10, 28, 46, 64],
  1136. [13, 40, 67, 94]])
  1137. >>> np.einsum('ki,...k->i...', a, b)
  1138. array([[10, 28, 46, 64],
  1139. [13, 40, 67, 94]])
  1140. >>> np.einsum('k...,jk', a, b)
  1141. array([[10, 28, 46, 64],
  1142. [13, 40, 67, 94]])
  1143. Chained array operations. For more complicated contractions, speed ups
  1144. might be achieved by repeatedly computing a 'greedy' path or pre-computing
  1145. the 'optimal' path and repeatedly applying it, using an `einsum_path`
  1146. insertion (since version 1.12.0). Performance improvements can be
  1147. particularly significant with larger arrays:
  1148. >>> a = np.ones(64).reshape(2,4,8)
  1149. Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.)
  1150. >>> for iteration in range(500):
  1151. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
  1152. Sub-optimal `einsum` (due to repeated path calculation time): ~330ms
  1153. >>> for iteration in range(500):
  1154. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
  1155. ... optimize='optimal')
  1156. Greedy `einsum` (faster optimal path approximation): ~160ms
  1157. >>> for iteration in range(500):
  1158. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
  1159. Optimal `einsum` (best usage pattern in some use cases): ~110ms
  1160. >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
  1161. ... optimize='optimal')[0]
  1162. >>> for iteration in range(500):
  1163. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)
  1164. """
  1165. # Special handling if out is specified
  1166. specified_out = out is not None
  1167. # If no optimization, run pure einsum
  1168. if optimize is False:
  1169. if specified_out:
  1170. kwargs['out'] = out
  1171. return c_einsum(*operands, **kwargs)
  1172. # Check the kwargs to avoid a more cryptic error later, without having to
  1173. # repeat default values here
  1174. valid_einsum_kwargs = ['dtype', 'order', 'casting']
  1175. unknown_kwargs = [k for (k, v) in kwargs.items() if
  1176. k not in valid_einsum_kwargs]
  1177. if len(unknown_kwargs):
  1178. raise TypeError("Did not understand the following kwargs: %s"
  1179. % unknown_kwargs)
  1180. # Build the contraction list and operand
  1181. operands, contraction_list = einsum_path(*operands, optimize=optimize,
  1182. einsum_call=True)
  1183. # Handle order kwarg for output array, c_einsum allows mixed case
  1184. output_order = kwargs.pop('order', 'K')
  1185. if output_order.upper() == 'A':
  1186. if all(arr.flags.f_contiguous for arr in operands):
  1187. output_order = 'F'
  1188. else:
  1189. output_order = 'C'
  1190. # Start contraction loop
  1191. for num, contraction in enumerate(contraction_list):
  1192. inds, idx_rm, einsum_str, remaining, blas = contraction
  1193. tmp_operands = [operands.pop(x) for x in inds]
  1194. # Do we need to deal with the output?
  1195. handle_out = specified_out and ((num + 1) == len(contraction_list))
  1196. # Call tensordot if still possible
  1197. if blas:
  1198. # Checks have already been handled
  1199. input_str, results_index = einsum_str.split('->')
  1200. input_left, input_right = input_str.split(',')
  1201. tensor_result = input_left + input_right
  1202. for s in idx_rm:
  1203. tensor_result = tensor_result.replace(s, "")
  1204. # Find indices to contract over
  1205. left_pos, right_pos = [], []
  1206. for s in sorted(idx_rm):
  1207. left_pos.append(input_left.find(s))
  1208. right_pos.append(input_right.find(s))
  1209. # Contract!
  1210. new_view = tensordot(
  1211. *tmp_operands, axes=(tuple(left_pos), tuple(right_pos))
  1212. )
  1213. # Build a new view if needed
  1214. if (tensor_result != results_index) or handle_out:
  1215. if handle_out:
  1216. kwargs["out"] = out
  1217. new_view = c_einsum(
  1218. tensor_result + '->' + results_index, new_view, **kwargs
  1219. )
  1220. # Call einsum
  1221. else:
  1222. # If out was specified
  1223. if handle_out:
  1224. kwargs["out"] = out
  1225. # Do the contraction
  1226. new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)
  1227. # Append new items and dereference what we can
  1228. operands.append(new_view)
  1229. del tmp_operands, new_view
  1230. if specified_out:
  1231. return out
  1232. else:
  1233. return asanyarray(operands[0], order=output_order)