# test_item_selection.py
import sys

import numpy as np
from numpy.testing import (
    assert_, assert_raises, assert_equal, assert_array_equal, HAS_REFCOUNT
    )
  6. class TestTake:
  7. def test_simple(self):
  8. a = [[1, 2], [3, 4]]
  9. a_str = [[b'1', b'2'], [b'3', b'4']]
  10. modes = ['raise', 'wrap', 'clip']
  11. indices = [-1, 4]
  12. index_arrays = [np.empty(0, dtype=np.intp),
  13. np.empty(tuple(), dtype=np.intp),
  14. np.empty((1, 1), dtype=np.intp)]
  15. real_indices = {'raise': {-1: 1, 4: IndexError},
  16. 'wrap': {-1: 1, 4: 0},
  17. 'clip': {-1: 0, 4: 1}}
  18. # Currently all types but object, use the same function generation.
  19. # So it should not be necessary to test all. However test also a non
  20. # refcounted struct on top of object, which has a size that hits the
  21. # default (non-specialized) path.
  22. types = int, object, np.dtype([('', 'i2', 3)])
  23. for t in types:
  24. # ta works, even if the array may be odd if buffer interface is used
  25. ta = np.array(a if np.issubdtype(t, np.number) else a_str, dtype=t)
  26. tresult = list(ta.T.copy())
  27. for index_array in index_arrays:
  28. if index_array.size != 0:
  29. tresult[0].shape = (2,) + index_array.shape
  30. tresult[1].shape = (2,) + index_array.shape
  31. for mode in modes:
  32. for index in indices:
  33. real_index = real_indices[mode][index]
  34. if real_index is IndexError and index_array.size != 0:
  35. index_array.put(0, index)
  36. assert_raises(IndexError, ta.take, index_array,
  37. mode=mode, axis=1)
  38. elif index_array.size != 0:
  39. index_array.put(0, index)
  40. res = ta.take(index_array, mode=mode, axis=1)
  41. assert_array_equal(res, tresult[real_index])
  42. else:
  43. res = ta.take(index_array, mode=mode, axis=1)
  44. assert_(res.shape == (2,) + index_array.shape)
  45. def test_refcounting(self):
  46. objects = [object() for i in range(10)]
  47. for mode in ('raise', 'clip', 'wrap'):
  48. a = np.array(objects)
  49. b = np.array([2, 2, 4, 5, 3, 5])
  50. a.take(b, out=a[:6], mode=mode)
  51. del a
  52. if HAS_REFCOUNT:
  53. assert_(all(sys.getrefcount(o) == 3 for o in objects))
  54. # not contiguous, example:
  55. a = np.array(objects * 2)[::2]
  56. a.take(b, out=a[:6], mode=mode)
  57. del a
  58. if HAS_REFCOUNT:
  59. assert_(all(sys.getrefcount(o) == 3 for o in objects))
  60. def test_unicode_mode(self):
  61. d = np.arange(10)
  62. k = b'\xc3\xa4'.decode("UTF8")
  63. assert_raises(ValueError, d.take, 5, mode=k)
  64. def test_empty_partition(self):
  65. # In reference to github issue #6530
  66. a_original = np.array([0, 2, 4, 6, 8, 10])
  67. a = a_original.copy()
  68. # An empty partition should be a successful no-op
  69. a.partition(np.array([], dtype=np.int16))
  70. assert_array_equal(a, a_original)
  71. def test_empty_argpartition(self):
  72. # In reference to github issue #6530
  73. a = np.array([0, 2, 4, 6, 8, 10])
  74. a = a.argpartition(np.array([], dtype=np.int16))
  75. b = np.array([0, 1, 2, 3, 4, 5])
  76. assert_array_equal(a, b)