einsumfunc.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415
  1. """
  2. Implementation of optimized einsum.
  3. """
  4. import itertools
  5. import operator
  6. from numpy.core.multiarray import c_einsum
  7. from numpy.core.numeric import asanyarray, tensordot
  8. from numpy.core.overrides import array_function_dispatch
  9. __all__ = ['einsum', 'einsum_path']
  10. einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
  11. einsum_symbols_set = set(einsum_symbols)
  12. def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
  13. """
  14. Computes the number of FLOPS in the contraction.
  15. Parameters
  16. ----------
  17. idx_contraction : iterable
  18. The indices involved in the contraction
  19. inner : bool
  20. Does this contraction require an inner product?
  21. num_terms : int
  22. The number of terms in a contraction
  23. size_dictionary : dict
  24. The size of each of the indices in idx_contraction
  25. Returns
  26. -------
  27. flop_count : int
  28. The total number of FLOPS required for the contraction.
  29. Examples
  30. --------
  31. >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
  32. 30
  33. >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
  34. 60
  35. """
  36. overall_size = _compute_size_by_dict(idx_contraction, size_dictionary)
  37. op_factor = max(1, num_terms - 1)
  38. if inner:
  39. op_factor += 1
  40. return overall_size * op_factor
  41. def _compute_size_by_dict(indices, idx_dict):
  42. """
  43. Computes the product of the elements in indices based on the dictionary
  44. idx_dict.
  45. Parameters
  46. ----------
  47. indices : iterable
  48. Indices to base the product on.
  49. idx_dict : dictionary
  50. Dictionary of index sizes
  51. Returns
  52. -------
  53. ret : int
  54. The resulting product.
  55. Examples
  56. --------
  57. >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
  58. 90
  59. """
  60. ret = 1
  61. for i in indices:
  62. ret *= idx_dict[i]
  63. return ret
  64. def _find_contraction(positions, input_sets, output_set):
  65. """
  66. Finds the contraction for a given set of input and output sets.
  67. Parameters
  68. ----------
  69. positions : iterable
  70. Integer positions of terms used in the contraction.
  71. input_sets : list
  72. List of sets that represent the lhs side of the einsum subscript
  73. output_set : set
  74. Set that represents the rhs side of the overall einsum subscript
  75. Returns
  76. -------
  77. new_result : set
  78. The indices of the resulting contraction
  79. remaining : list
  80. List of sets that have not been contracted, the new set is appended to
  81. the end of this list
  82. idx_removed : set
  83. Indices removed from the entire contraction
  84. idx_contraction : set
  85. The indices used in the current contraction
  86. Examples
  87. --------
  88. # A simple dot product test case
  89. >>> pos = (0, 1)
  90. >>> isets = [set('ab'), set('bc')]
  91. >>> oset = set('ac')
  92. >>> _find_contraction(pos, isets, oset)
  93. ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
  94. # A more complex case with additional terms in the contraction
  95. >>> pos = (0, 2)
  96. >>> isets = [set('abd'), set('ac'), set('bdc')]
  97. >>> oset = set('ac')
  98. >>> _find_contraction(pos, isets, oset)
  99. ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
  100. """
  101. idx_contract = set()
  102. idx_remain = output_set.copy()
  103. remaining = []
  104. for ind, value in enumerate(input_sets):
  105. if ind in positions:
  106. idx_contract |= value
  107. else:
  108. remaining.append(value)
  109. idx_remain |= value
  110. new_result = idx_remain & idx_contract
  111. idx_removed = (idx_contract - new_result)
  112. remaining.append(new_result)
  113. return (new_result, remaining, idx_removed, idx_contract)
def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # Each entry is (cumulative_cost, positions_so_far, remaining_term_sets).
    full_results = [(0, [], input_sets)]
    for iteration in range(len(input_sets) - 1):
        iter_results = []

        # Compute all unique pairs; expand every surviving partial path by
        # every possible pair contraction of its remaining terms.
        for curr in full_results:
            cost, positions, remaining = curr
            for con in itertools.combinations(range(len(input_sets) - iteration), 2):

                # Find the contraction
                cont = _find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = cont

                # Sieve the results based on memory_limit: discard candidates
                # whose intermediate would exceed the allowed size.
                new_size = _compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + _flop_count(idx_contract, idx_removed, len(con), idx_dict)
                new_pos = positions + [con]
                iter_results.append((total_cost, new_pos, new_input_sets))

        # Update combinatorial list, if we did not find anything return best
        # path + remaining contractions (collapse the rest into one einsum call)
        if iter_results:
            full_results = iter_results
        else:
            path = min(full_results, key=lambda x: x[0])[1]
            path += [tuple(range(len(input_sets) - iteration))]
            return path

    # If we have not found anything return single einsum contraction
    if len(full_results) == 0:
        return [tuple(range(len(input_sets)))]

    # Cheapest complete path wins.
    path = min(full_results, key=lambda x: x[0])[1]
    return path
  172. def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
  173. """Compute the cost (removed size + flops) and resultant indices for
  174. performing the contraction specified by ``positions``.
  175. Parameters
  176. ----------
  177. positions : tuple of int
  178. The locations of the proposed tensors to contract.
  179. input_sets : list of sets
  180. The indices found on each tensors.
  181. output_set : set
  182. The output indices of the expression.
  183. idx_dict : dict
  184. Mapping of each index to its size.
  185. memory_limit : int
  186. The total allowed size for an intermediary tensor.
  187. path_cost : int
  188. The contraction cost so far.
  189. naive_cost : int
  190. The cost of the unoptimized expression.
  191. Returns
  192. -------
  193. cost : (int, int)
  194. A tuple containing the size of any indices removed, and the flop cost.
  195. positions : tuple of int
  196. The locations of the proposed tensors to contract.
  197. new_input_sets : list of sets
  198. The resulting new list of indices if this proposed contraction is performed.
  199. """
  200. # Find the contraction
  201. contract = _find_contraction(positions, input_sets, output_set)
  202. idx_result, new_input_sets, idx_removed, idx_contract = contract
  203. # Sieve the results based on memory_limit
  204. new_size = _compute_size_by_dict(idx_result, idx_dict)
  205. if new_size > memory_limit:
  206. return None
  207. # Build sort tuple
  208. old_sizes = (_compute_size_by_dict(input_sets[p], idx_dict) for p in positions)
  209. removed_size = sum(old_sizes) - new_size
  210. # NB: removed_size used to be just the size of any removed indices i.e.:
  211. # helpers.compute_size_by_dict(idx_removed, idx_dict)
  212. cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
  213. sort = (-removed_size, cost)
  214. # Sieve based on total cost as well
  215. if (path_cost + cost) > naive_cost:
  216. return None
  217. # Add contraction to possible choices
  218. return [sort, positions, new_input_sets]
def _update_other_results(results, best):
    """Update the positions and provisional input_sets of ``results`` based on
    performing the contraction result ``best``. Remove any involving the tensors
    contracted.

    Parameters
    ----------
    results : list
        List of contraction results produced by ``_parse_possible_contraction``.
    best : list
        The best contraction of ``results`` i.e. the one that will be performed.

    Returns
    -------
    mod_results : list
        The list of modified results, updated with outcome of ``best`` contraction.
    """
    best_con = best[1]
    bx, by = best_con
    mod_results = []
    for cost, (x, y), con_sets in results:

        # Ignore results involving tensors just contracted
        if x in best_con or y in best_con:
            continue

        # Update the input_sets: drop the two terms consumed by ``best``.
        # Each candidate's ``con_sets`` was built with its own (x, y) already
        # removed, so bx/by are shifted down by one for every removed index
        # that preceded them.
        del con_sets[by - int(by > x) - int(by > y)]
        del con_sets[bx - int(bx > x) - int(bx > y)]
        # Insert best's new intermediate just before this candidate's own
        # provisional result (which stays at the end of the list).
        con_sets.insert(-1, best[2][-1])

        # Update the position indices: shift (x, y) down to compensate for
        # the removal of bx and by from the term list.
        mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
        mod_results.append((cost, mod_con, con_sets))

    return mod_results
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``.  What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost (contracting everything in one shot) used as an
    # upper bound when sieving candidates.
    contract = _find_contraction(range(len(input_sets)), input_sets, output_set)
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on first step, only previously found pairs on subsequent steps
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost,
                                                 naive_cost)
            if result is not None:
                known_contractions.append(result)

        # If we do not have a inner contraction, rescan pairs including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(range(len(input_sets)), 2):
                result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit,
                                                     path_cost, naive_cost)
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions, default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Sort based on first index: best candidate minimizes
        # (-removed_size, cost) as built by _parse_possible_contraction.
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible to next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost
        path.append(best[1])
        path_cost += best[0][1]

    return path
  329. def _can_dot(inputs, result, idx_removed):
  330. """
  331. Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.
  332. Parameters
  333. ----------
  334. inputs : list of str
  335. Specifies the subscripts for summation.
  336. result : str
  337. Resulting summation.
  338. idx_removed : set
  339. Indices that are removed in the summation
  340. Returns
  341. -------
  342. type : bool
  343. Returns true if BLAS should and can be used, else False
  344. Notes
  345. -----
  346. If the operations is BLAS level 1 or 2 and is not already aligned
  347. we default back to einsum as the memory movement to copy is more
  348. costly than the operation itself.
  349. Examples
  350. --------
  351. # Standard GEMM operation
  352. >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
  353. True
  354. # Can use the standard BLAS, but requires odd data movement
  355. >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
  356. False
  357. # DDOT where the memory is not aligned
  358. >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
  359. False
  360. """
  361. # All `dot` calls remove indices
  362. if len(idx_removed) == 0:
  363. return False
  364. # BLAS can only handle two operands
  365. if len(inputs) != 2:
  366. return False
  367. input_left, input_right = inputs
  368. for c in set(input_left + input_right):
  369. # can't deal with repeated indices on same input or more than 2 total
  370. nl, nr = input_left.count(c), input_right.count(c)
  371. if (nl > 1) or (nr > 1) or (nl + nr > 2):
  372. return False
  373. # can't do implicit summation or dimension collapse e.g.
  374. # "ab,bc->c" (implicitly sum over 'a')
  375. # "ab,ca->ca" (take diagonal of 'a')
  376. if nl + nr - 1 == int(c in result):
  377. return False
  378. # Build a few temporaries
  379. set_left = set(input_left)
  380. set_right = set(input_right)
  381. keep_left = set_left - idx_removed
  382. keep_right = set_right - idx_removed
  383. rs = len(idx_removed)
  384. # At this point we are a DOT, GEMV, or GEMM operation
  385. # Handle inner products
  386. # DDOT with aligned data
  387. if input_left == input_right:
  388. return True
  389. # DDOT without aligned data (better to use einsum)
  390. if set_left == set_right:
  391. return False
  392. # Handle the 4 possible (aligned) GEMV or GEMM cases
  393. # GEMM or GEMV no transpose
  394. if input_left[-rs:] == input_right[:rs]:
  395. return True
  396. # GEMM or GEMV transpose both
  397. if input_left[:rs] == input_right[-rs:]:
  398. return True
  399. # GEMM or GEMV transpose right
  400. if input_left[-rs:] == input_right[-rs:]:
  401. return True
  402. # GEMM or GEMV transpose left
  403. if input_left[:rs] == input_right[:rs]:
  404. return True
  405. # Einsum is faster than GEMV if we have to copy data
  406. if not keep_left or not keep_right:
  407. return False
  408. # We are a matrix-matrix product, but we need to copy data
  409. return True
def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        # Subscript-string form: ('ij,jk->ik', a, b)
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: (a, [0, 1], b, [1, 2], [0, 2]) — operands
        # alternate with axis-label lists, optionally followed by an
        # output label list.
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for p in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        # A trailing odd element, if present, is the output label list.
        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = [asanyarray(v) for v in operand_list]

        # Translate integer labels into subscript letters.
        subscripts = ""
        last = len(subscript_list) - 1
        for num, sub in enumerate(subscript_list):
            for s in sub:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError("For this input type lists must contain "
                                        "either int or Ellipsis") from e
                    subscripts += einsum_symbols[s]
            if num != last:
                subscripts += ","

        if output_list is not None:
            subscripts += "->"
            for s in output_list:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError("For this input type lists must contain "
                                        "either int or Ellipsis") from e
                    subscripts += einsum_symbols[s]
    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses: replace each '...' with concrete letters drawn from
    # the pool of symbols not otherwise used in the expression.
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: broadcast (ellipsis)
            # dimensions come first, then indices appearing exactly once,
            # sorted alphabetically.
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts: indices appearing exactly once, sorted.
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)
  558. def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
  559. # NOTE: technically, we should only dispatch on array-like arguments, not
  560. # subscripts (given as strings). But separating operands into
  561. # arrays/subscripts is a little tricky/slow (given einsum's two supported
  562. # signatures), so as a practical shortcut we dispatch on everything.
  563. # Strings will be ignored for dispatching since they don't define
  564. # __array_function__.
  565. return operands
  566. @array_function_dispatch(_einsum_path_dispatcher, module='numpy')
  567. def einsum_path(*operands, optimize='greedy', einsum_call=False):
  568. """
  569. einsum_path(subscripts, *operands, optimize='greedy')
  570. Evaluates the lowest cost contraction order for an einsum expression by
  571. considering the creation of intermediate arrays.
  572. Parameters
  573. ----------
  574. subscripts : str
  575. Specifies the subscripts for summation.
  576. *operands : list of array_like
  577. These are the arrays for the operation.
  578. optimize : {bool, list, tuple, 'greedy', 'optimal'}
  579. Choose the type of path. If a tuple is provided, the second argument is
  580. assumed to be the maximum intermediate size created. If only a single
  581. argument is provided the largest input or output array size is used
  582. as a maximum intermediate size.
  583. * if a list is given that starts with ``einsum_path``, uses this as the
  584. contraction path
  585. * if False no optimization is taken
  586. * if True defaults to the 'greedy' algorithm
  587. * 'optimal' An algorithm that combinatorially explores all possible
  588. ways of contracting the listed tensors and choosest the least costly
  589. path. Scales exponentially with the number of terms in the
  590. contraction.
  591. * 'greedy' An algorithm that chooses the best pair contraction
  592. at each step. Effectively, this algorithm searches the largest inner,
  593. Hadamard, and then outer products at each step. Scales cubically with
  594. the number of terms in the contraction. Equivalent to the 'optimal'
  595. path for most contractions.
  596. Default is 'greedy'.
  597. Returns
  598. -------
  599. path : list of tuples
  600. A list representation of the einsum path.
  601. string_repr : str
  602. A printable representation of the einsum path.
  603. Notes
  604. -----
  605. The resulting path indicates which terms of the input contraction should be
  606. contracted first, the result of this contraction is then appended to the
  607. end of the contraction list. This list can then be iterated over until all
  608. intermediate contractions are complete.
  609. See Also
  610. --------
  611. einsum, linalg.multi_dot
  612. Examples
  613. --------
  614. We can begin with a chain dot example. In this case, it is optimal to
  615. contract the ``b`` and ``c`` tensors first as represented by the first
  616. element of the path ``(1, 2)``. The resulting tensor is added to the end
  617. of the contraction and the remaining contraction ``(0, 1)`` is then
  618. completed.
  619. >>> np.random.seed(123)
  620. >>> a = np.random.rand(2, 2)
  621. >>> b = np.random.rand(2, 5)
  622. >>> c = np.random.rand(5, 2)
  623. >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
  624. >>> print(path_info[0])
  625. ['einsum_path', (1, 2), (0, 1)]
  626. >>> print(path_info[1])
  627. Complete contraction: ij,jk,kl->il # may vary
  628. Naive scaling: 4
  629. Optimized scaling: 3
  630. Naive FLOP count: 1.600e+02
  631. Optimized FLOP count: 5.600e+01
  632. Theoretical speedup: 2.857
  633. Largest intermediate: 4.000e+00 elements
  634. -------------------------------------------------------------------------
  635. scaling current remaining
  636. -------------------------------------------------------------------------
  637. 3 kl,jk->jl ij,jl->il
  638. 3 jl,ij->il il->il
  639. A more complex index transformation example.
  640. >>> I = np.random.rand(10, 10, 10, 10)
  641. >>> C = np.random.rand(10, 10)
  642. >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
  643. ... optimize='greedy')
  644. >>> print(path_info[0])
  645. ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
  646. >>> print(path_info[1])
  647. Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
  648. Naive scaling: 8
  649. Optimized scaling: 5
  650. Naive FLOP count: 8.000e+08
  651. Optimized FLOP count: 8.000e+05
  652. Theoretical speedup: 1000.000
  653. Largest intermediate: 1.000e+04 elements
  654. --------------------------------------------------------------------------
  655. scaling current remaining
  656. --------------------------------------------------------------------------
  657. 5 abcd,ea->bcde fb,gc,hd,bcde->efgh
  658. 5 bcde,fb->cdef gc,hd,cdef->efgh
  659. 5 cdef,gc->defg hd,defg->efgh
  660. 5 defg,hd->efgh efgh->efgh
  661. """
  662. # Figure out what the path really is
  663. path_type = optimize
  664. if path_type is True:
  665. path_type = 'greedy'
  666. if path_type is None:
  667. path_type = False
  668. memory_limit = None
  669. # No optimization or a named path algorithm
  670. if (path_type is False) or isinstance(path_type, str):
  671. pass
  672. # Given an explicit path
  673. elif len(path_type) and (path_type[0] == 'einsum_path'):
  674. pass
  675. # Path tuple with memory limit
  676. elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
  677. isinstance(path_type[1], (int, float))):
  678. memory_limit = int(path_type[1])
  679. path_type = path_type[0]
  680. else:
  681. raise TypeError("Did not understand the path: %s" % str(path_type))
  682. # Hidden option, only einsum should call this
  683. einsum_call_arg = einsum_call
  684. # Python side parsing
  685. input_subscripts, output_subscript, operands = _parse_einsum_input(operands)
  686. # Build a few useful list and sets
  687. input_list = input_subscripts.split(',')
  688. input_sets = [set(x) for x in input_list]
  689. output_set = set(output_subscript)
  690. indices = set(input_subscripts.replace(',', ''))
  691. # Get length of each unique dimension and ensure all dimensions are correct
  692. dimension_dict = {}
  693. broadcast_indices = [[] for x in range(len(input_list))]
  694. for tnum, term in enumerate(input_list):
  695. sh = operands[tnum].shape
  696. if len(sh) != len(term):
  697. raise ValueError("Einstein sum subscript %s does not contain the "
  698. "correct number of indices for operand %d."
  699. % (input_subscripts[tnum], tnum))
  700. for cnum, char in enumerate(term):
  701. dim = sh[cnum]
  702. # Build out broadcast indices
  703. if dim == 1:
  704. broadcast_indices[tnum].append(char)
  705. if char in dimension_dict.keys():
  706. # For broadcasting cases we always want the largest dim size
  707. if dimension_dict[char] == 1:
  708. dimension_dict[char] = dim
  709. elif dim not in (1, dimension_dict[char]):
  710. raise ValueError("Size of label '%s' for operand %d (%d) "
  711. "does not match previous terms (%d)."
  712. % (char, tnum, dimension_dict[char], dim))
  713. else:
  714. dimension_dict[char] = dim
  715. # Convert broadcast inds to sets
  716. broadcast_indices = [set(x) for x in broadcast_indices]
  717. # Compute size of each input array plus the output array
  718. size_list = [_compute_size_by_dict(term, dimension_dict)
  719. for term in input_list + [output_subscript]]
  720. max_size = max(size_list)
  721. if memory_limit is None:
  722. memory_arg = max_size
  723. else:
  724. memory_arg = memory_limit
  725. # Compute naive cost
  726. # This isn't quite right, need to look into exactly how einsum does this
  727. inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
  728. naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)
  729. # Compute the path
  730. if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set):
  731. # Nothing to be optimized, leave it to einsum
  732. path = [tuple(range(len(input_list)))]
  733. elif path_type == "greedy":
  734. path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
  735. elif path_type == "optimal":
  736. path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
  737. elif path_type[0] == 'einsum_path':
  738. path = path_type[1:]
  739. else:
  740. raise KeyError("Path name %s not found", path_type)
  741. cost_list, scale_list, size_list, contraction_list = [], [], [], []
  742. # Build contraction tuple (positions, gemm, einsum_str, remaining)
  743. for cnum, contract_inds in enumerate(path):
  744. # Make sure we remove inds from right to left
  745. contract_inds = tuple(sorted(list(contract_inds), reverse=True))
  746. contract = _find_contraction(contract_inds, input_sets, output_set)
  747. out_inds, input_sets, idx_removed, idx_contract = contract
  748. cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
  749. cost_list.append(cost)
  750. scale_list.append(len(idx_contract))
  751. size_list.append(_compute_size_by_dict(out_inds, dimension_dict))
  752. bcast = set()
  753. tmp_inputs = []
  754. for x in contract_inds:
  755. tmp_inputs.append(input_list.pop(x))
  756. bcast |= broadcast_indices.pop(x)
  757. new_bcast_inds = bcast - idx_removed
  758. # If we're broadcasting, nix blas
  759. if not len(idx_removed & bcast):
  760. do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
  761. else:
  762. do_blas = False
  763. # Last contraction
  764. if (cnum - len(path)) == -1:
  765. idx_result = output_subscript
  766. else:
  767. sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
  768. idx_result = "".join([x[1] for x in sorted(sort_result)])
  769. input_list.append(idx_result)
  770. broadcast_indices.append(new_bcast_inds)
  771. einsum_str = ",".join(tmp_inputs) + "->" + idx_result
  772. contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
  773. contraction_list.append(contraction)
  774. opt_cost = sum(cost_list) + 1
  775. if einsum_call_arg:
  776. return (operands, contraction_list)
  777. # Return the path along with a nice string representation
  778. overall_contraction = input_subscripts + "->" + output_subscript
  779. header = ("scaling", "current", "remaining")
  780. speedup = naive_cost / opt_cost
  781. max_i = max(size_list)
  782. path_print = " Complete contraction: %s\n" % overall_contraction
  783. path_print += " Naive scaling: %d\n" % len(indices)
  784. path_print += " Optimized scaling: %d\n" % max(scale_list)
  785. path_print += " Naive FLOP count: %.3e\n" % naive_cost
  786. path_print += " Optimized FLOP count: %.3e\n" % opt_cost
  787. path_print += " Theoretical speedup: %3.3f\n" % speedup
  788. path_print += " Largest intermediate: %.3e elements\n" % max_i
  789. path_print += "-" * 74 + "\n"
  790. path_print += "%6s %24s %40s\n" % header
  791. path_print += "-" * 74
  792. for n, contraction in enumerate(contraction_list):
  793. inds, idx_rm, einsum_str, remaining, blas = contraction
  794. remaining_str = ",".join(remaining) + "->" + output_subscript
  795. path_run = (scale_list[n], einsum_str, remaining_str)
  796. path_print += "\n%4d %24s %40s" % path_run
  797. path = ['einsum_path'] + path
  798. return (path, path_print)
  799. def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
  800. # Arguably we dispatch on more arguments that we really should; see note in
  801. # _einsum_path_dispatcher for why.
  802. yield from operands
  803. yield out
  804. # Rewrite einsum to handle different cases
  805. @array_function_dispatch(_einsum_dispatcher, module='numpy')
  806. def einsum(*operands, out=None, optimize=False, **kwargs):
  807. """
  808. einsum(subscripts, *operands, out=None, dtype=None, order='K',
  809. casting='safe', optimize=False)
  810. Evaluates the Einstein summation convention on the operands.
  811. Using the Einstein summation convention, many common multi-dimensional,
  812. linear algebraic array operations can be represented in a simple fashion.
  813. In *implicit* mode `einsum` computes these values.
  814. In *explicit* mode, `einsum` provides further flexibility to compute
  815. other array operations that might not be considered classical Einstein
  816. summation operations, by disabling, or forcing summation over specified
  817. subscript labels.
  818. See the notes and examples for clarification.
  819. Parameters
  820. ----------
  821. subscripts : str
  822. Specifies the subscripts for summation as comma separated list of
  823. subscript labels. An implicit (classical Einstein summation)
  824. calculation is performed unless the explicit indicator '->' is
  825. included as well as subscript labels of the precise output form.
  826. operands : list of array_like
  827. These are the arrays for the operation.
  828. out : ndarray, optional
  829. If provided, the calculation is done into this array.
  830. dtype : {data-type, None}, optional
  831. If provided, forces the calculation to use the data type specified.
  832. Note that you may have to also give a more liberal `casting`
  833. parameter to allow the conversions. Default is None.
  834. order : {'C', 'F', 'A', 'K'}, optional
  835. Controls the memory layout of the output. 'C' means it should
  836. be C contiguous. 'F' means it should be Fortran contiguous,
  837. 'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
  838. 'K' means it should be as close to the layout as the inputs as
  839. is possible, including arbitrarily permuted axes.
  840. Default is 'K'.
  841. casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
  842. Controls what kind of data casting may occur. Setting this to
  843. 'unsafe' is not recommended, as it can adversely affect accumulations.
  844. * 'no' means the data types should not be cast at all.
  845. * 'equiv' means only byte-order changes are allowed.
  846. * 'safe' means only casts which can preserve values are allowed.
  847. * 'same_kind' means only safe casts or casts within a kind,
  848. like float64 to float32, are allowed.
  849. * 'unsafe' means any data conversions may be done.
  850. Default is 'safe'.
  851. optimize : {False, True, 'greedy', 'optimal'}, optional
  852. Controls if intermediate optimization should occur. No optimization
  853. will occur if False and True will default to the 'greedy' algorithm.
  854. Also accepts an explicit contraction list from the ``np.einsum_path``
  855. function. See ``np.einsum_path`` for more details. Defaults to False.
  856. Returns
  857. -------
  858. output : ndarray
  859. The calculation based on the Einstein summation convention.
  860. See Also
  861. --------
  862. einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
  863. Notes
  864. -----
  865. .. versionadded:: 1.6.0
  866. The Einstein summation convention can be used to compute
  867. many multi-dimensional, linear algebraic array operations. `einsum`
  868. provides a succinct way of representing these.
  869. A non-exhaustive list of these operations,
  870. which can be computed by `einsum`, is shown below along with examples:
  871. * Trace of an array, :py:func:`numpy.trace`.
  872. * Return a diagonal, :py:func:`numpy.diag`.
  873. * Array axis summations, :py:func:`numpy.sum`.
  874. * Transpositions and permutations, :py:func:`numpy.transpose`.
  875. * Matrix multiplication and dot product, :py:func:`numpy.matmul` :py:func:`numpy.dot`.
  876. * Vector inner and outer products, :py:func:`numpy.inner` :py:func:`numpy.outer`.
  877. * Broadcasting, element-wise and scalar multiplication, :py:func:`numpy.multiply`.
  878. * Tensor contractions, :py:func:`numpy.tensordot`.
  879. * Chained array operations, in efficient calculation order, :py:func:`numpy.einsum_path`.
  880. The subscripts string is a comma-separated list of subscript labels,
  881. where each label refers to a dimension of the corresponding operand.
  882. Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
  883. is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
  884. appears only once, it is not summed, so ``np.einsum('i', a)`` produces a
  885. view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)``
  886. describes traditional matrix multiplication and is equivalent to
  887. :py:func:`np.matmul(a,b) <numpy.matmul>`. Repeated subscript labels in one
  888. operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent
  889. to :py:func:`np.trace(a) <numpy.trace>`.
  890. In *implicit mode*, the chosen subscripts are important
  891. since the axes of the output are reordered alphabetically. This
  892. means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
  893. ``np.einsum('ji', a)`` takes its transpose. Additionally,
  894. ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
  895. ``np.einsum('ij,jh', a, b)`` returns the transpose of the
  896. multiplication since subscript 'h' precedes subscript 'i'.
  897. In *explicit mode* the output can be directly controlled by
  898. specifying output subscript labels. This requires the
  899. identifier '->' as well as the list of output subscript labels.
  900. This feature increases the flexibility of the function since
  901. summing can be disabled or forced when required. The call
  902. ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <numpy.sum>`,
  903. and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <numpy.diag>`.
  904. The difference is that `einsum` does not allow broadcasting by default.
  905. Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
  906. order of the output subscript labels and therefore returns matrix
  907. multiplication, unlike the example above in implicit mode.
  908. To enable and control broadcasting, use an ellipsis. Default
  909. NumPy-style broadcasting is done by adding an ellipsis
  910. to the left of each term, like ``np.einsum('...ii->...i', a)``.
  911. To take the trace along the first and last axes,
  912. you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
  913. product with the left-most indices instead of rightmost, one can do
  914. ``np.einsum('ij...,jk...->ik...', a, b)``.
  915. When there is only one operand, no axes are summed, and no output
  916. parameter is provided, a view into the operand is returned instead
  917. of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
  918. produces a view (changed in version 1.10.0).
  919. `einsum` also provides an alternative way to provide the subscripts
  920. and operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
  921. If the output shape is not provided in this format `einsum` will be
  922. calculated in implicit mode, otherwise it will be performed explicitly.
  923. The examples below have corresponding `einsum` calls with the two
  924. parameter methods.
  925. .. versionadded:: 1.10.0
  926. Views returned from einsum are now writeable whenever the input array
  927. is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
  928. have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
  929. and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
  930. of a 2D array.
  931. .. versionadded:: 1.12.0
  932. Added the ``optimize`` argument which will optimize the contraction order
  933. of an einsum expression. For a contraction with three or more operands this
  934. can greatly increase the computational efficiency at the cost of a larger
  935. memory footprint during computation.
  936. Typically a 'greedy' algorithm is applied which empirical tests have shown
  937. returns the optimal path in the majority of cases. In some cases 'optimal'
  938. will return the superlative path through a more expensive, exhaustive search.
  939. For iterative calculations it may be advisable to calculate the optimal path
  940. once and reuse that path by supplying it as an argument. An example is given
  941. below.
  942. See :py:func:`numpy.einsum_path` for more details.
  943. Examples
  944. --------
  945. >>> a = np.arange(25).reshape(5,5)
  946. >>> b = np.arange(5)
  947. >>> c = np.arange(6).reshape(2,3)
  948. Trace of a matrix:
  949. >>> np.einsum('ii', a)
  950. 60
  951. >>> np.einsum(a, [0,0])
  952. 60
  953. >>> np.trace(a)
  954. 60
  955. Extract the diagonal (requires explicit form):
  956. >>> np.einsum('ii->i', a)
  957. array([ 0, 6, 12, 18, 24])
  958. >>> np.einsum(a, [0,0], [0])
  959. array([ 0, 6, 12, 18, 24])
  960. >>> np.diag(a)
  961. array([ 0, 6, 12, 18, 24])
  962. Sum over an axis (requires explicit form):
  963. >>> np.einsum('ij->i', a)
  964. array([ 10, 35, 60, 85, 110])
  965. >>> np.einsum(a, [0,1], [0])
  966. array([ 10, 35, 60, 85, 110])
  967. >>> np.sum(a, axis=1)
  968. array([ 10, 35, 60, 85, 110])
  969. For higher dimensional arrays summing a single axis can be done with ellipsis:
  970. >>> np.einsum('...j->...', a)
  971. array([ 10, 35, 60, 85, 110])
  972. >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
  973. array([ 10, 35, 60, 85, 110])
  974. Compute a matrix transpose, or reorder any number of axes:
  975. >>> np.einsum('ji', c)
  976. array([[0, 3],
  977. [1, 4],
  978. [2, 5]])
  979. >>> np.einsum('ij->ji', c)
  980. array([[0, 3],
  981. [1, 4],
  982. [2, 5]])
  983. >>> np.einsum(c, [1,0])
  984. array([[0, 3],
  985. [1, 4],
  986. [2, 5]])
  987. >>> np.transpose(c)
  988. array([[0, 3],
  989. [1, 4],
  990. [2, 5]])
  991. Vector inner products:
  992. >>> np.einsum('i,i', b, b)
  993. 30
  994. >>> np.einsum(b, [0], b, [0])
  995. 30
  996. >>> np.inner(b,b)
  997. 30
  998. Matrix vector multiplication:
  999. >>> np.einsum('ij,j', a, b)
  1000. array([ 30, 80, 130, 180, 230])
  1001. >>> np.einsum(a, [0,1], b, [1])
  1002. array([ 30, 80, 130, 180, 230])
  1003. >>> np.dot(a, b)
  1004. array([ 30, 80, 130, 180, 230])
  1005. >>> np.einsum('...j,j', a, b)
  1006. array([ 30, 80, 130, 180, 230])
  1007. Broadcasting and scalar multiplication:
  1008. >>> np.einsum('..., ...', 3, c)
  1009. array([[ 0, 3, 6],
  1010. [ 9, 12, 15]])
  1011. >>> np.einsum(',ij', 3, c)
  1012. array([[ 0, 3, 6],
  1013. [ 9, 12, 15]])
  1014. >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
  1015. array([[ 0, 3, 6],
  1016. [ 9, 12, 15]])
  1017. >>> np.multiply(3, c)
  1018. array([[ 0, 3, 6],
  1019. [ 9, 12, 15]])
  1020. Vector outer product:
  1021. >>> np.einsum('i,j', np.arange(2)+1, b)
  1022. array([[0, 1, 2, 3, 4],
  1023. [0, 2, 4, 6, 8]])
  1024. >>> np.einsum(np.arange(2)+1, [0], b, [1])
  1025. array([[0, 1, 2, 3, 4],
  1026. [0, 2, 4, 6, 8]])
  1027. >>> np.outer(np.arange(2)+1, b)
  1028. array([[0, 1, 2, 3, 4],
  1029. [0, 2, 4, 6, 8]])
  1030. Tensor contraction:
  1031. >>> a = np.arange(60.).reshape(3,4,5)
  1032. >>> b = np.arange(24.).reshape(4,3,2)
  1033. >>> np.einsum('ijk,jil->kl', a, b)
  1034. array([[4400., 4730.],
  1035. [4532., 4874.],
  1036. [4664., 5018.],
  1037. [4796., 5162.],
  1038. [4928., 5306.]])
  1039. >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
  1040. array([[4400., 4730.],
  1041. [4532., 4874.],
  1042. [4664., 5018.],
  1043. [4796., 5162.],
  1044. [4928., 5306.]])
  1045. >>> np.tensordot(a,b, axes=([1,0],[0,1]))
  1046. array([[4400., 4730.],
  1047. [4532., 4874.],
  1048. [4664., 5018.],
  1049. [4796., 5162.],
  1050. [4928., 5306.]])
  1051. Writeable returned arrays (since version 1.10.0):
  1052. >>> a = np.zeros((3, 3))
  1053. >>> np.einsum('ii->i', a)[:] = 1
  1054. >>> a
  1055. array([[1., 0., 0.],
  1056. [0., 1., 0.],
  1057. [0., 0., 1.]])
  1058. Example of ellipsis use:
  1059. >>> a = np.arange(6).reshape((3,2))
  1060. >>> b = np.arange(12).reshape((4,3))
  1061. >>> np.einsum('ki,jk->ij', a, b)
  1062. array([[10, 28, 46, 64],
  1063. [13, 40, 67, 94]])
  1064. >>> np.einsum('ki,...k->i...', a, b)
  1065. array([[10, 28, 46, 64],
  1066. [13, 40, 67, 94]])
  1067. >>> np.einsum('k...,jk', a, b)
  1068. array([[10, 28, 46, 64],
  1069. [13, 40, 67, 94]])
  1070. Chained array operations. For more complicated contractions, speed ups
  1071. might be achieved by repeatedly computing a 'greedy' path or pre-computing the
  1072. 'optimal' path and repeatedly applying it, using an
  1073. `einsum_path` insertion (since version 1.12.0). Performance improvements can be
  1074. particularly significant with larger arrays:
  1075. >>> a = np.ones(64).reshape(2,4,8)
  1076. Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.)
  1077. >>> for iteration in range(500):
  1078. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
  1079. Sub-optimal `einsum` (due to repeated path calculation time): ~330ms
  1080. >>> for iteration in range(500):
  1081. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
  1082. Greedy `einsum` (faster optimal path approximation): ~160ms
  1083. >>> for iteration in range(500):
  1084. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
  1085. Optimal `einsum` (best usage pattern in some use cases): ~110ms
  1086. >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
  1087. >>> for iteration in range(500):
  1088. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)
  1089. """
  1090. # Special handling if out is specified
  1091. specified_out = out is not None
  1092. # If no optimization, run pure einsum
  1093. if optimize is False:
  1094. if specified_out:
  1095. kwargs['out'] = out
  1096. return c_einsum(*operands, **kwargs)
  1097. # Check the kwargs to avoid a more cryptic error later, without having to
  1098. # repeat default values here
  1099. valid_einsum_kwargs = ['dtype', 'order', 'casting']
  1100. unknown_kwargs = [k for (k, v) in kwargs.items() if
  1101. k not in valid_einsum_kwargs]
  1102. if len(unknown_kwargs):
  1103. raise TypeError("Did not understand the following kwargs: %s"
  1104. % unknown_kwargs)
  1105. # Build the contraction list and operand
  1106. operands, contraction_list = einsum_path(*operands, optimize=optimize,
  1107. einsum_call=True)
  1108. # Start contraction loop
  1109. for num, contraction in enumerate(contraction_list):
  1110. inds, idx_rm, einsum_str, remaining, blas = contraction
  1111. tmp_operands = [operands.pop(x) for x in inds]
  1112. # Do we need to deal with the output?
  1113. handle_out = specified_out and ((num + 1) == len(contraction_list))
  1114. # Call tensordot if still possible
  1115. if blas:
  1116. # Checks have already been handled
  1117. input_str, results_index = einsum_str.split('->')
  1118. input_left, input_right = input_str.split(',')
  1119. tensor_result = input_left + input_right
  1120. for s in idx_rm:
  1121. tensor_result = tensor_result.replace(s, "")
  1122. # Find indices to contract over
  1123. left_pos, right_pos = [], []
  1124. for s in sorted(idx_rm):
  1125. left_pos.append(input_left.find(s))
  1126. right_pos.append(input_right.find(s))
  1127. # Contract!
  1128. new_view = tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)))
  1129. # Build a new view if needed
  1130. if (tensor_result != results_index) or handle_out:
  1131. if handle_out:
  1132. kwargs["out"] = out
  1133. new_view = c_einsum(tensor_result + '->' + results_index, new_view, **kwargs)
  1134. # Call einsum
  1135. else:
  1136. # If out was specified
  1137. if handle_out:
  1138. kwargs["out"] = out
  1139. # Do the contraction
  1140. new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)
  1141. # Append new items and dereference what we can
  1142. operands.append(new_view)
  1143. del tmp_operands, new_view
  1144. if specified_out:
  1145. return out
  1146. else:
  1147. return operands[0]