dual.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
"""
Aliases for functions which may be accelerated by Scipy.

Scipy_ can be built to use accelerated or otherwise improved libraries
for FFTs, linear algebra, and special functions. This module allows
developers to transparently support these accelerated functions when
scipy is available but still support users who have only installed
NumPy.

.. _Scipy: https://www.scipy.org

"""
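# Use this module for functions that exist in both NumPy and SciPy when you
# want the SciPy version if it is available and the NumPy version otherwise.
# The names below are bound to the NumPy implementations by default;
# register_func() can rebind them to alternative implementations.
# Usage --- from numpy.dual import fft, inv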
  14. __all__ = ['fft', 'ifft', 'fftn', 'ifftn', 'fft2', 'ifft2',
  15. 'norm', 'inv', 'svd', 'solve', 'det', 'eig', 'eigvals',
  16. 'eigh', 'eigvalsh', 'lstsq', 'pinv', 'cholesky', 'i0']
  17. import numpy.linalg as linpkg
  18. import numpy.fft as fftpkg
  19. from numpy.lib import i0
  20. import sys
  21. fft = fftpkg.fft
  22. ifft = fftpkg.ifft
  23. fftn = fftpkg.fftn
  24. ifftn = fftpkg.ifftn
  25. fft2 = fftpkg.fft2
  26. ifft2 = fftpkg.ifft2
  27. norm = linpkg.norm
  28. inv = linpkg.inv
  29. svd = linpkg.svd
  30. solve = linpkg.solve
  31. det = linpkg.det
  32. eig = linpkg.eig
  33. eigvals = linpkg.eigvals
  34. eigh = linpkg.eigh
  35. eigvalsh = linpkg.eigvalsh
  36. lstsq = linpkg.lstsq
  37. pinv = linpkg.pinv
  38. cholesky = linpkg.cholesky
  39. _restore_dict = {}
  40. def register_func(name, func):
  41. if name not in __all__:
  42. raise ValueError("{} not a dual function.".format(name))
  43. f = sys._getframe(0).f_globals
  44. _restore_dict[name] = f[name]
  45. f[name] = func
  46. def restore_func(name):
  47. if name not in __all__:
  48. raise ValueError("{} not a dual function.".format(name))
  49. try:
  50. val = _restore_dict[name]
  51. except KeyError:
  52. return
  53. else:
  54. sys._getframe(0).f_globals[name] = val
  55. def restore_all():
  56. for name in _restore_dict.keys():
  57. restore_func(name)