""".. _dispatch_mechanism:

Numpy's dispatch mechanism, introduced in numpy v1.16, is the
recommended approach for writing custom N-dimensional array containers that are
compatible with the numpy API and provide custom implementations of numpy
functionality. Applications include `dask <http://dask.pydata.org>`_ arrays, an
N-dimensional array distributed across multiple nodes, and `cupy
<https://docs-cupy.chainer.org/en/stable/>`_ arrays, an N-dimensional array on
a GPU.

To get a feel for writing custom array containers, we'll begin with a simple
example that has rather narrow utility but illustrates the concepts involved.

>>> import numpy as np
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         return self._i * np.eye(self._N, dtype=dtype)
...

Our custom array can be instantiated like this:

>>> arr = DiagonalArray(5, 1)
>>> arr
DiagonalArray(N=5, value=1)

We can convert to a numpy array using :func:`numpy.array` or
:func:`numpy.asarray`, which will call its ``__array__`` method to obtain a
standard ``numpy.ndarray``. The optional ``dtype`` and ``copy`` parameters let
numpy forward a requested dtype and, in newer numpy versions, a copy
preference; this simple implementation honors ``dtype`` and always builds a
new array.
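
The sketch below contrasts a hypothetical object that defines ``__array__`` with one that does not; the class names ``Dense2x2`` and ``Opaque`` are illustrative only. Without ``__array__`` (or another array protocol), numpy cannot see inside the object and wraps it in a 0-dimensional array of dtype ``object``.

```python
import numpy as np


class Dense2x2:
    # Hypothetical container that exposes its data only through __array__.
    def __array__(self, dtype=None, copy=None):
        return np.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtype)


class Opaque:
    # Hypothetical object with no __array__ method at all.
    pass


converted = np.asarray(Dense2x2())
wrapped = np.asarray(Opaque())

print(converted.shape, converted.dtype)  # (2, 2) float64
print(wrapped.shape, wrapped.dtype)      # () object
```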

>>> np.asarray(arr)
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

If we operate on ``arr`` with a numpy function, numpy will again use the
``__array__`` interface to convert it to an array and then apply the function
in the usual way.

>>> np.multiply(arr, 2)
array([[2., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [0., 0., 2., 0., 0.],
       [0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 2.]])

Notice that the return type is a standard ``numpy.ndarray``.

>>> type(np.multiply(arr, 2))
<class 'numpy.ndarray'>
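
Going through ``__array__`` means each such call first builds the full dense matrix, even though a ``DiagonalArray`` is described by just two numbers. A minimal sketch, repeating only the ``__array__`` part of the class above, makes the cost concrete for ``N = 1000``: the result holds 1000 by 1000 float64 entries, or 8,000,000 bytes, and the converted input is the same size.

```python
import numpy as np


class DiagonalArray:
    # Minimal copy of the example class: only __array__ is defined, so numpy
    # functions convert it to a dense ndarray before computing.
    def __init__(self, N, value):
        self._N = N
        self._i = value

    def __array__(self, dtype=None, copy=None):
        return self._i * np.eye(self._N, dtype=dtype)


big = DiagonalArray(1000, 1)
result = np.multiply(big, 2)

print(type(result).__name__)  # ndarray
print(result.nbytes)          # 8000000 (1000 * 1000 entries of 8 bytes each)
```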

How can we pass our custom array type through this function? Numpy allows a
class to indicate that it would like to handle computations in a custom-defined
way through the interfaces ``__array_ufunc__`` and ``__array_function__``. Let's
take one at a time, starting with ``__array_ufunc__``. This method covers
:ref:`ufuncs`, a class of functions that includes, for example,
:func:`numpy.multiply` and :func:`numpy.sin`.
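
Which of the two protocols a function goes through depends on whether it is a ufunc, that is, an instance of :class:`numpy.ufunc`. Ufuncs dispatch through ``__array_ufunc__``; other overridable numpy functions, such as ``np.sum`` and ``np.concatenate`` used later in this document, dispatch through ``__array_function__``. A quick check, as a small sketch:

```python
import numpy as np

# Functions named in this document, keyed by name.
candidates = {
    'multiply': np.multiply,
    'sin': np.sin,
    'sum': np.sum,
    'concatenate': np.concatenate,
}
for name, f in candidates.items():
    # True means f is a ufunc and is dispatched through __array_ufunc__.
    print(name, isinstance(f, np.ufunc))
# multiply True, sin True, sum False, concatenate False
```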

The ``__array_ufunc__`` method receives:

- ``ufunc``, a function like ``numpy.multiply``
- ``method``, a string, differentiating between ``numpy.multiply(...)`` and
  variants like ``numpy.multiply.outer``, ``numpy.multiply.accumulate``, and so
  on. For the common case, ``numpy.multiply(...)``, ``method == '__call__'``.
- ``inputs``, which could be a mixture of different types
- ``kwargs``, keyword arguments passed to the function
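
To make these arguments concrete, the sketch below uses a hypothetical ``Spy`` class whose ``__array_ufunc__`` returns what it receives instead of computing anything, so each ufunc call evaluates to that summary. The ``casting='unsafe'`` keyword is just an arbitrary example of a ufunc keyword argument being forwarded in ``kwargs``.

```python
import numpy as np


class Spy:
    # Hypothetical class that reports what numpy passes to __array_ufunc__.
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        return ufunc.__name__, method, inputs, kwargs


s = Spy()
calls = {
    'plain call': np.multiply(s, 2),
    'outer': np.multiply.outer(s, 2),
    'accumulate': np.multiply.accumulate(s),
    'with keyword': np.multiply(s, 2, casting='unsafe'),
}
for label, (name, method, inputs, kwargs) in calls.items():
    # method distinguishes np.multiply(...) from np.multiply.outer(...) and so on.
    print(f'{label}: ufunc={name} method={method} n_inputs={len(inputs)} kwargs={kwargs}')
```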

For this example we will only handle the method ``__call__``.

>>> from numbers import Number
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     # Take N from the DiagonalArray inputs; they must agree.
...                     if N is not None:
...                         if N != input._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = input._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...

Now our custom array type passes through numpy functions.

>>> arr = DiagonalArray(5, 1)
>>> np.multiply(arr, 3)
DiagonalArray(N=5, value=3)
>>> np.add(arr, 3)
DiagonalArray(N=5, value=4)
>>> np.sin(arr)
DiagonalArray(N=5, value=0.8414709848078965)
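
The ``__array_ufunc__`` above applies each ufunc to the diagonal value only, a simplification that keeps the example short. It matches what numpy would compute on the dense matrix only when the ufunc maps the zero off-diagonal entries back to zero: true for ``np.multiply`` by a scalar and for ``np.sin``, but not for ``np.add`` with a nonzero scalar, where the dense result has the scalar in every off-diagonal position. The sketch below checks this on plain ndarrays; the helper ``diagonal_only`` is hypothetical.

```python
import numpy as np

N, value = 5, 1
dense = value * np.eye(N)


def diagonal_only(result_value):
    # Hypothetical helper: the dense matrix a DiagonalArray(N, result_value) stands for.
    return result_value * np.eye(N)


# f(0) == 0 off the diagonal, so the diagonal shortcut matches numpy.
print(np.allclose(np.multiply(dense, 3), diagonal_only(np.multiply(value, 3))))  # True
print(np.allclose(np.sin(dense), diagonal_only(np.sin(value))))                  # True

# 0 + 3 == 3 off the diagonal, so the shortcut drops those entries.
print(np.allclose(np.add(dense, 3), diagonal_only(np.add(value, 3))))            # False
```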

At this point ``arr + 3`` does not work.

>>> arr + 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for +: 'DiagonalArray' and 'int'

To support it, we need to define the Python operator methods ``__add__``,
``__lt__``, and so on, so that they dispatch to the corresponding ufunc. We can
achieve this conveniently by inheriting from the mixin
:class:`~numpy.lib.mixins.NDArrayOperatorsMixin`.
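
Written out by hand for just three operators, that forwarding might look like the sketch below. It illustrates the idea rather than reproducing the mixin's source; ``Wrapped`` is a hypothetical one-value container. The mixin covers the full operator set, including reflected forms such as ``__radd__``, and also returns ``NotImplemented`` for operands that opt out by setting ``__array_ufunc__ = None``.

```python
import numpy as np


class Wrapped:
    # Hypothetical container holding a single number; handles ufunc calls only.
    def __init__(self, value):
        self.value = value

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method != '__call__':
            return NotImplemented
        args = [x.value if isinstance(x, Wrapped) else x for x in inputs]
        return Wrapped(ufunc(*args, **kwargs))

    # Hand-written operator methods, each forwarding to the matching ufunc.
    def __add__(self, other):
        return np.add(self, other)

    def __radd__(self, other):
        # Reached for 3 + w, after int.__add__ returns NotImplemented.
        return np.add(other, self)

    def __lt__(self, other):
        return np.less(self, other)


w = Wrapped(1)
print((w + 3).value)  # 4
print((3 + w).value)  # 4
print((w < 0).value)  # False
```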

>>> import numpy.lib.mixins
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != input._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = input._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...
>>> arr = DiagonalArray(5, 1)
>>> arr + 3
DiagonalArray(N=5, value=4)
>>> arr > 0
DiagonalArray(N=5, value=True)
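
The mixin also supplies reflected and unary operators, so expressions such as ``3 + arr`` or ``-arr`` reach ``__array_ufunc__`` as well. The sketch below checks this on a hypothetical one-value container ``Wrapped`` that follows the same ``__call__``-only pattern as ``DiagonalArray``.

```python
import numpy as np
from numpy.lib.mixins import NDArrayOperatorsMixin


class Wrapped(NDArrayOperatorsMixin):
    # Hypothetical container holding a single number.
    def __init__(self, value):
        self.value = value

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Handle plain calls only, unwrapping Wrapped inputs to their values.
        if method != '__call__':
            return NotImplemented
        args = [x.value if isinstance(x, Wrapped) else x for x in inputs]
        return Wrapped(ufunc(*args, **kwargs))


w = Wrapped(2)
print((3 + w).value)           # 5, via __radd__ calling np.add(3, w)
print((-w).value)              # -2, via __neg__ calling np.negative(w)
print(abs(Wrapped(-4)).value)  # 4, via __abs__ calling np.absolute
print((w ** 2).value)          # 4, via __pow__ calling np.power(w, 2)
```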

Now let's tackle ``__array_function__``. We'll create a dict that maps numpy
functions to our custom variants.

>>> HANDLED_FUNCTIONS = {}
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 # In this case we accept only scalar numbers or DiagonalArrays.
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != input._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = input._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...     def __array_function__(self, func, types, args, kwargs):
...         if func not in HANDLED_FUNCTIONS:
...             return NotImplemented
...         # Note: this allows subclasses that don't override
...         # __array_function__ to handle DiagonalArray objects.
...         if not all(issubclass(t, self.__class__) for t in types):
...             return NotImplemented
...         return HANDLED_FUNCTIONS[func](*args, **kwargs)
...
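
The sketch below makes the ``__array_function__`` arguments concrete with a hypothetical ``Spy`` class that returns them unchanged: ``func`` is the public numpy function that was called, ``types`` holds the argument types implementing the protocol, and ``args``/``kwargs`` are the arguments as the caller passed them. It also shows what the ``issubclass`` guard does: with two hypothetical types ``A`` and ``B`` that each handle only their own kind, a mixed call finds no taker and numpy raises ``TypeError``.

```python
import numpy as np


class Spy:
    # Hypothetical class reporting what numpy passes to __array_function__.
    def __array_function__(self, func, types, args, kwargs):
        return func, types, args, kwargs


s = Spy()
func, types, args, kwargs = np.concatenate([s, s], axis=0)
print(func is np.concatenate)  # True
print(Spy in types)            # True
print(kwargs)                  # {'axis': 0}


class A:
    # Hypothetical type that, like DiagonalArray, only handles its own kind.
    def __array_function__(self, func, types, args, kwargs):
        if not all(issubclass(t, A) for t in types):
            return NotImplemented
        return 'handled by A'


class B:
    # A second hypothetical type with the same guard.
    def __array_function__(self, func, types, args, kwargs):
        if not all(issubclass(t, B) for t in types):
            return NotImplemented
        return 'handled by B'


print(np.concatenate([A(), A()]))  # handled by A

error = None
try:
    np.concatenate([A(), B()])
except TypeError as exc:
    # A and B both returned NotImplemented, so numpy has nothing to call.
    error = exc
print(type(error).__name__)  # TypeError
```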

A convenient pattern is to define a decorator ``implements`` that can be used
to add functions to ``HANDLED_FUNCTIONS``.

>>> def implements(np_function):
...     "Register an __array_function__ implementation for DiagonalArray objects."
...     def decorator(func):
...         HANDLED_FUNCTIONS[np_function] = func
...         return func
...     return decorator
...

Now we write implementations of numpy functions for ``DiagonalArray``.
For completeness, to support the usage ``arr.sum()``, add a method ``sum`` that
calls ``numpy.sum(self)``, and likewise for ``mean``.

>>> @implements(np.sum)
... def sum(arr):
...     "Implementation of np.sum for DiagonalArray objects"
...     return arr._i * arr._N
...
>>> @implements(np.mean)
... def mean(arr):
...     "Implementation of np.mean for DiagonalArray objects"
...     return arr._i / arr._N
...
>>> arr = DiagonalArray(5, 1)
>>> np.sum(arr)
5
>>> np.mean(arr)
0.2
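
Following the advice to add ``sum`` and ``mean`` methods, a minimal sketch of those methods is shown below. It keeps only the parts of ``DiagonalArray`` needed here, and the implementation names ``diagonal_sum`` and ``diagonal_mean`` are illustrative (chosen to avoid shadowing the built-in ``sum``).

```python
import numpy as np

HANDLED_FUNCTIONS = {}


def implements(np_function):
    # Register an __array_function__ implementation for DiagonalArray objects.
    def decorator(func):
        HANDLED_FUNCTIONS[np_function] = func
        return func
    return decorator


class DiagonalArray:
    # Trimmed-down version of the class above: only what sum/mean need.
    def __init__(self, N, value):
        self._N = N
        self._i = value

    def __array_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

    # Thin methods that route arr.sum() and arr.mean() back through numpy.
    def sum(self):
        return np.sum(self)

    def mean(self):
        return np.mean(self)


@implements(np.sum)
def diagonal_sum(arr):
    return arr._i * arr._N


@implements(np.mean)
def diagonal_mean(arr):
    return arr._i / arr._N


arr = DiagonalArray(5, 1)
print(arr.sum())   # 5
print(arr.mean())  # 0.2
```

Keeping the methods thin means the behavior lives in one place, the ``HANDLED_FUNCTIONS`` registry, so ``arr.sum()`` and ``np.sum(arr)`` cannot drift apart.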

If the user tries to use any numpy functions not included in
``HANDLED_FUNCTIONS``, a ``TypeError`` will be raised by numpy, indicating that
this operation is not supported. For example, concatenating two
``DiagonalArray`` objects does not produce another diagonal array, so it is not
supported.

>>> np.concatenate([arr, arr])
Traceback (most recent call last):
...
TypeError: no implementation found for 'numpy.concatenate' on types that implement __array_function__: [<class '__main__.DiagonalArray'>]

Additionally, our implementations of ``sum`` and ``mean`` do not accept the
optional arguments that numpy's implementation does.

>>> np.sum(arr, axis=0)
Traceback (most recent call last):
...
TypeError: sum() got an unexpected keyword argument 'axis'
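
One way to close part of that gap is to accept the keyword arguments that have a cheap diagonal answer and let the rest fail loudly. The sketch below is a hypothetical extension, not part of the original example: ``axis=None`` returns the total, while ``axis`` equal to 0 or 1 (or -2/-1) returns a length-``N`` array, since every row and every column of a diagonal matrix contains exactly one copy of the value. Other arguments such as ``keepdims`` still raise ``TypeError``, and the choice of ``ValueError`` for an out-of-range axis is an assumption.

```python
import numpy as np

HANDLED_FUNCTIONS = {}


def implements(np_function):
    # Register an __array_function__ implementation for DiagonalArray objects.
    def decorator(func):
        HANDLED_FUNCTIONS[np_function] = func
        return func
    return decorator


class DiagonalArray:
    # Trimmed-down version of the class above.
    def __init__(self, N, value):
        self._N = N
        self._i = value

    def __array__(self, dtype=None, copy=None):
        return self._i * np.eye(self._N, dtype=dtype)

    def __array_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(np.sum)
def diagonal_sum(arr, axis=None):
    # Hypothetical axis-aware sum; unsupported keywords still raise TypeError.
    if axis is None:
        return arr._i * arr._N
    if axis in (0, 1, -1, -2):
        # Each row and each column holds exactly one diagonal entry.
        return np.full(arr._N, arr._i)
    raise ValueError(f'axis {axis} is out of bounds for a 2-D array')


arr = DiagonalArray(5, 1)
print(np.sum(arr))          # 5
print(np.sum(arr, axis=0))  # [1 1 1 1 1]
# Agrees with summing the dense matrix along the same axis.
print(np.array_equal(np.sum(arr, axis=1), np.sum(np.asarray(arr), axis=1)))  # True
```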

The user always has the option of converting to a normal ``numpy.ndarray`` with
:func:`numpy.asarray` and using standard numpy from there.

>>> np.concatenate([np.asarray(arr), np.asarray(arr)])
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

Refer to the `dask source code <https://github.com/dask/dask>`_ and
`cupy source code <https://github.com/cupy/cupy>`_ for more fully-worked
examples of custom array containers.

See also `NEP 18 <http://www.numpy.org/neps/nep-0018-array-function-protocol.html>`_.
"""