test_collections.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662
  1. import io
  2. from types import SimpleNamespace
  3. import numpy as np
  4. from numpy.testing import assert_array_equal, assert_array_almost_equal
  5. import pytest
  6. import matplotlib as mpl
  7. import matplotlib.pyplot as plt
  8. import matplotlib.collections as mcollections
  9. import matplotlib.transforms as mtransforms
  10. from matplotlib.collections import (Collection, LineCollection,
  11. EventCollection, PolyCollection)
  12. from matplotlib.testing.decorators import image_comparison
  13. def generate_EventCollection_plot():
  14. """Generate the initial collection and plot it."""
  15. positions = np.array([0., 1., 2., 3., 5., 8., 13., 21.])
  16. extra_positions = np.array([34., 55., 89.])
  17. orientation = 'horizontal'
  18. lineoffset = 1
  19. linelength = .5
  20. linewidth = 2
  21. color = [1, 0, 0, 1]
  22. linestyle = 'solid'
  23. antialiased = True
  24. coll = EventCollection(positions,
  25. orientation=orientation,
  26. lineoffset=lineoffset,
  27. linelength=linelength,
  28. linewidth=linewidth,
  29. color=color,
  30. linestyle=linestyle,
  31. antialiased=antialiased
  32. )
  33. fig = plt.figure()
  34. ax = fig.add_subplot(1, 1, 1)
  35. ax.add_collection(coll)
  36. ax.set_title('EventCollection: default')
  37. props = {'positions': positions,
  38. 'extra_positions': extra_positions,
  39. 'orientation': orientation,
  40. 'lineoffset': lineoffset,
  41. 'linelength': linelength,
  42. 'linewidth': linewidth,
  43. 'color': color,
  44. 'linestyle': linestyle,
  45. 'antialiased': antialiased
  46. }
  47. ax.set_xlim(-1, 22)
  48. ax.set_ylim(0, 2)
  49. return ax, coll, props
  50. @image_comparison(['EventCollection_plot__default'])
  51. def test__EventCollection__get_props():
  52. _, coll, props = generate_EventCollection_plot()
  53. # check that the default segments have the correct coordinates
  54. check_segments(coll,
  55. props['positions'],
  56. props['linelength'],
  57. props['lineoffset'],
  58. props['orientation'])
  59. # check that the default positions match the input positions
  60. np.testing.assert_array_equal(props['positions'], coll.get_positions())
  61. # check that the default orientation matches the input orientation
  62. assert props['orientation'] == coll.get_orientation()
  63. # check that the default orientation matches the input orientation
  64. assert coll.is_horizontal()
  65. # check that the default linelength matches the input linelength
  66. assert props['linelength'] == coll.get_linelength()
  67. # check that the default lineoffset matches the input lineoffset
  68. assert props['lineoffset'] == coll.get_lineoffset()
  69. # check that the default linestyle matches the input linestyle
  70. assert coll.get_linestyle() == [(0, None)]
  71. # check that the default color matches the input color
  72. for color in [coll.get_color(), *coll.get_colors()]:
  73. np.testing.assert_array_equal(color, props['color'])
  74. @image_comparison(['EventCollection_plot__set_positions'])
  75. def test__EventCollection__set_positions():
  76. splt, coll, props = generate_EventCollection_plot()
  77. new_positions = np.hstack([props['positions'], props['extra_positions']])
  78. coll.set_positions(new_positions)
  79. np.testing.assert_array_equal(new_positions, coll.get_positions())
  80. check_segments(coll, new_positions,
  81. props['linelength'],
  82. props['lineoffset'],
  83. props['orientation'])
  84. splt.set_title('EventCollection: set_positions')
  85. splt.set_xlim(-1, 90)
  86. @image_comparison(['EventCollection_plot__add_positions'])
  87. def test__EventCollection__add_positions():
  88. splt, coll, props = generate_EventCollection_plot()
  89. new_positions = np.hstack([props['positions'],
  90. props['extra_positions'][0]])
  91. coll.switch_orientation() # Test adding in the vertical orientation, too.
  92. coll.add_positions(props['extra_positions'][0])
  93. coll.switch_orientation()
  94. np.testing.assert_array_equal(new_positions, coll.get_positions())
  95. check_segments(coll,
  96. new_positions,
  97. props['linelength'],
  98. props['lineoffset'],
  99. props['orientation'])
  100. splt.set_title('EventCollection: add_positions')
  101. splt.set_xlim(-1, 35)
  102. @image_comparison(['EventCollection_plot__append_positions'])
  103. def test__EventCollection__append_positions():
  104. splt, coll, props = generate_EventCollection_plot()
  105. new_positions = np.hstack([props['positions'],
  106. props['extra_positions'][2]])
  107. coll.append_positions(props['extra_positions'][2])
  108. np.testing.assert_array_equal(new_positions, coll.get_positions())
  109. check_segments(coll,
  110. new_positions,
  111. props['linelength'],
  112. props['lineoffset'],
  113. props['orientation'])
  114. splt.set_title('EventCollection: append_positions')
  115. splt.set_xlim(-1, 90)
  116. @image_comparison(['EventCollection_plot__extend_positions'])
  117. def test__EventCollection__extend_positions():
  118. splt, coll, props = generate_EventCollection_plot()
  119. new_positions = np.hstack([props['positions'],
  120. props['extra_positions'][1:]])
  121. coll.extend_positions(props['extra_positions'][1:])
  122. np.testing.assert_array_equal(new_positions, coll.get_positions())
  123. check_segments(coll,
  124. new_positions,
  125. props['linelength'],
  126. props['lineoffset'],
  127. props['orientation'])
  128. splt.set_title('EventCollection: extend_positions')
  129. splt.set_xlim(-1, 90)
  130. @image_comparison(['EventCollection_plot__switch_orientation'])
  131. def test__EventCollection__switch_orientation():
  132. splt, coll, props = generate_EventCollection_plot()
  133. new_orientation = 'vertical'
  134. coll.switch_orientation()
  135. assert new_orientation == coll.get_orientation()
  136. assert not coll.is_horizontal()
  137. new_positions = coll.get_positions()
  138. check_segments(coll,
  139. new_positions,
  140. props['linelength'],
  141. props['lineoffset'], new_orientation)
  142. splt.set_title('EventCollection: switch_orientation')
  143. splt.set_ylim(-1, 22)
  144. splt.set_xlim(0, 2)
  145. @image_comparison(['EventCollection_plot__switch_orientation__2x'])
  146. def test__EventCollection__switch_orientation_2x():
  147. """
  148. Check that calling switch_orientation twice sets the orientation back to
  149. the default.
  150. """
  151. splt, coll, props = generate_EventCollection_plot()
  152. coll.switch_orientation()
  153. coll.switch_orientation()
  154. new_positions = coll.get_positions()
  155. assert props['orientation'] == coll.get_orientation()
  156. assert coll.is_horizontal()
  157. np.testing.assert_array_equal(props['positions'], new_positions)
  158. check_segments(coll,
  159. new_positions,
  160. props['linelength'],
  161. props['lineoffset'],
  162. props['orientation'])
  163. splt.set_title('EventCollection: switch_orientation 2x')
  164. @image_comparison(['EventCollection_plot__set_orientation'])
  165. def test__EventCollection__set_orientation():
  166. splt, coll, props = generate_EventCollection_plot()
  167. new_orientation = 'vertical'
  168. coll.set_orientation(new_orientation)
  169. assert new_orientation == coll.get_orientation()
  170. assert not coll.is_horizontal()
  171. check_segments(coll,
  172. props['positions'],
  173. props['linelength'],
  174. props['lineoffset'],
  175. new_orientation)
  176. splt.set_title('EventCollection: set_orientation')
  177. splt.set_ylim(-1, 22)
  178. splt.set_xlim(0, 2)
  179. @image_comparison(['EventCollection_plot__set_linelength'])
  180. def test__EventCollection__set_linelength():
  181. splt, coll, props = generate_EventCollection_plot()
  182. new_linelength = 15
  183. coll.set_linelength(new_linelength)
  184. assert new_linelength == coll.get_linelength()
  185. check_segments(coll,
  186. props['positions'],
  187. new_linelength,
  188. props['lineoffset'],
  189. props['orientation'])
  190. splt.set_title('EventCollection: set_linelength')
  191. splt.set_ylim(-20, 20)
  192. @image_comparison(['EventCollection_plot__set_lineoffset'])
  193. def test__EventCollection__set_lineoffset():
  194. splt, coll, props = generate_EventCollection_plot()
  195. new_lineoffset = -5.
  196. coll.set_lineoffset(new_lineoffset)
  197. assert new_lineoffset == coll.get_lineoffset()
  198. check_segments(coll,
  199. props['positions'],
  200. props['linelength'],
  201. new_lineoffset,
  202. props['orientation'])
  203. splt.set_title('EventCollection: set_lineoffset')
  204. splt.set_ylim(-6, -4)
  205. @image_comparison([
  206. 'EventCollection_plot__set_linestyle',
  207. 'EventCollection_plot__set_linestyle',
  208. 'EventCollection_plot__set_linewidth',
  209. ])
  210. def test__EventCollection__set_prop():
  211. for prop, value, expected in [
  212. ('linestyle', 'dashed', [(0, (6.0, 6.0))]),
  213. ('linestyle', (0, (6., 6.)), [(0, (6.0, 6.0))]),
  214. ('linewidth', 5, 5),
  215. ]:
  216. splt, coll, _ = generate_EventCollection_plot()
  217. coll.set(**{prop: value})
  218. assert plt.getp(coll, prop) == expected
  219. splt.set_title(f'EventCollection: set_{prop}')
  220. @image_comparison(['EventCollection_plot__set_color'])
  221. def test__EventCollection__set_color():
  222. splt, coll, _ = generate_EventCollection_plot()
  223. new_color = np.array([0, 1, 1, 1])
  224. coll.set_color(new_color)
  225. for color in [coll.get_color(), *coll.get_colors()]:
  226. np.testing.assert_array_equal(color, new_color)
  227. splt.set_title('EventCollection: set_color')
  228. def check_segments(coll, positions, linelength, lineoffset, orientation):
  229. """
  230. Test helper checking that all values in the segment are correct, given a
  231. particular set of inputs.
  232. """
  233. segments = coll.get_segments()
  234. if (orientation.lower() == 'horizontal'
  235. or orientation.lower() == 'none' or orientation is None):
  236. # if horizontal, the position in is in the y-axis
  237. pos1 = 1
  238. pos2 = 0
  239. elif orientation.lower() == 'vertical':
  240. # if vertical, the position in is in the x-axis
  241. pos1 = 0
  242. pos2 = 1
  243. else:
  244. raise ValueError("orientation must be 'horizontal' or 'vertical'")
  245. # test to make sure each segment is correct
  246. for i, segment in enumerate(segments):
  247. assert segment[0, pos1] == lineoffset + linelength / 2
  248. assert segment[1, pos1] == lineoffset - linelength / 2
  249. assert segment[0, pos2] == positions[i]
  250. assert segment[1, pos2] == positions[i]
  251. def test_null_collection_datalim():
  252. col = mcollections.PathCollection([])
  253. col_data_lim = col.get_datalim(mtransforms.IdentityTransform())
  254. assert_array_equal(col_data_lim.get_points(),
  255. mtransforms.Bbox.null().get_points())
  256. def test_add_collection():
  257. # Test if data limits are unchanged by adding an empty collection.
  258. # GitHub issue #1490, pull #1497.
  259. plt.figure()
  260. ax = plt.axes()
  261. coll = ax.scatter([0, 1], [0, 1])
  262. ax.add_collection(coll)
  263. bounds = ax.dataLim.bounds
  264. coll = ax.scatter([], [])
  265. assert ax.dataLim.bounds == bounds
  266. def test_quiver_limits():
  267. ax = plt.axes()
  268. x, y = np.arange(8), np.arange(10)
  269. u = v = np.linspace(0, 10, 80).reshape(10, 8)
  270. q = plt.quiver(x, y, u, v)
  271. assert q.get_datalim(ax.transData).bounds == (0., 0., 7., 9.)
  272. plt.figure()
  273. ax = plt.axes()
  274. x = np.linspace(-5, 10, 20)
  275. y = np.linspace(-2, 4, 10)
  276. y, x = np.meshgrid(y, x)
  277. trans = mtransforms.Affine2D().translate(25, 32) + ax.transData
  278. plt.quiver(x, y, np.sin(x), np.cos(y), transform=trans)
  279. assert ax.dataLim.bounds == (20.0, 30.0, 15.0, 6.0)
  280. def test_barb_limits():
  281. ax = plt.axes()
  282. x = np.linspace(-5, 10, 20)
  283. y = np.linspace(-2, 4, 10)
  284. y, x = np.meshgrid(y, x)
  285. trans = mtransforms.Affine2D().translate(25, 32) + ax.transData
  286. plt.barbs(x, y, np.sin(x), np.cos(y), transform=trans)
  287. # The calculated bounds are approximately the bounds of the original data,
  288. # this is because the entire path is taken into account when updating the
  289. # datalim.
  290. assert_array_almost_equal(ax.dataLim.bounds, (20, 30, 15, 6),
  291. decimal=1)
  292. @image_comparison(['EllipseCollection_test_image.png'], remove_text=True)
  293. def test_EllipseCollection():
  294. # Test basic functionality
  295. fig, ax = plt.subplots()
  296. x = np.arange(4)
  297. y = np.arange(3)
  298. X, Y = np.meshgrid(x, y)
  299. XY = np.vstack((X.ravel(), Y.ravel())).T
  300. ww = X / x[-1]
  301. hh = Y / y[-1]
  302. aa = np.ones_like(ww) * 20 # first axis is 20 degrees CCW from x axis
  303. ec = mcollections.EllipseCollection(ww, hh, aa,
  304. units='x',
  305. offsets=XY,
  306. transOffset=ax.transData,
  307. facecolors='none')
  308. ax.add_collection(ec)
  309. ax.autoscale_view()
  310. @image_comparison(['polycollection_close.png'], remove_text=True)
  311. def test_polycollection_close():
  312. from mpl_toolkits.mplot3d import Axes3D
  313. vertsQuad = [
  314. [[0., 0.], [0., 1.], [1., 1.], [1., 0.]],
  315. [[0., 1.], [2., 3.], [2., 2.], [1., 1.]],
  316. [[2., 2.], [2., 3.], [4., 1.], [3., 1.]],
  317. [[3., 0.], [3., 1.], [4., 1.], [4., 0.]]]
  318. fig = plt.figure()
  319. ax = Axes3D(fig)
  320. colors = ['r', 'g', 'b', 'y', 'k']
  321. zpos = list(range(5))
  322. poly = mcollections.PolyCollection(
  323. vertsQuad * len(zpos), linewidth=0.25)
  324. poly.set_alpha(0.7)
  325. # need to have a z-value for *each* polygon = element!
  326. zs = []
  327. cs = []
  328. for z, c in zip(zpos, colors):
  329. zs.extend([z] * len(vertsQuad))
  330. cs.extend([c] * len(vertsQuad))
  331. poly.set_color(cs)
  332. ax.add_collection3d(poly, zs=zs, zdir='y')
  333. # axis limit settings:
  334. ax.set_xlim3d(0, 4)
  335. ax.set_zlim3d(0, 3)
  336. ax.set_ylim3d(0, 4)
  337. @image_comparison(['regularpolycollection_rotate.png'], remove_text=True)
  338. def test_regularpolycollection_rotate():
  339. xx, yy = np.mgrid[:10, :10]
  340. xy_points = np.transpose([xx.flatten(), yy.flatten()])
  341. rotations = np.linspace(0, 2*np.pi, len(xy_points))
  342. fig, ax = plt.subplots()
  343. for xy, alpha in zip(xy_points, rotations):
  344. col = mcollections.RegularPolyCollection(
  345. 4, sizes=(100,), rotation=alpha,
  346. offsets=[xy], transOffset=ax.transData)
  347. ax.add_collection(col, autolim=True)
  348. ax.autoscale_view()
  349. @image_comparison(['regularpolycollection_scale.png'], remove_text=True)
  350. def test_regularpolycollection_scale():
  351. # See issue #3860
  352. class SquareCollection(mcollections.RegularPolyCollection):
  353. def __init__(self, **kwargs):
  354. super().__init__(4, rotation=np.pi/4., **kwargs)
  355. def get_transform(self):
  356. """Return transform scaling circle areas to data space."""
  357. ax = self.axes
  358. pts2pixels = 72.0 / ax.figure.dpi
  359. scale_x = pts2pixels * ax.bbox.width / ax.viewLim.width
  360. scale_y = pts2pixels * ax.bbox.height / ax.viewLim.height
  361. return mtransforms.Affine2D().scale(scale_x, scale_y)
  362. fig, ax = plt.subplots()
  363. xy = [(0, 0)]
  364. # Unit square has a half-diagonal of `1/sqrt(2)`, so `pi * r**2` equals...
  365. circle_areas = [np.pi / 2]
  366. squares = SquareCollection(sizes=circle_areas, offsets=xy,
  367. transOffset=ax.transData)
  368. ax.add_collection(squares, autolim=True)
  369. ax.axis([-1, 1, -1, 1])
  370. def test_picking():
  371. fig, ax = plt.subplots()
  372. col = ax.scatter([0], [0], [1000], picker=True)
  373. fig.savefig(io.BytesIO(), dpi=fig.dpi)
  374. mouse_event = SimpleNamespace(x=325, y=240)
  375. found, indices = col.contains(mouse_event)
  376. assert found
  377. assert_array_equal(indices['ind'], [0])
  378. def test_linestyle_single_dashes():
  379. plt.scatter([0, 1, 2], [0, 1, 2], linestyle=(0., [2., 2.]))
  380. plt.draw()
  381. @image_comparison(['size_in_xy.png'], remove_text=True)
  382. def test_size_in_xy():
  383. fig, ax = plt.subplots()
  384. widths, heights, angles = (10, 10), 10, 0
  385. widths = 10, 10
  386. coords = [(10, 10), (15, 15)]
  387. e = mcollections.EllipseCollection(
  388. widths, heights, angles,
  389. units='xy',
  390. offsets=coords,
  391. transOffset=ax.transData)
  392. ax.add_collection(e)
  393. ax.set_xlim(0, 30)
  394. ax.set_ylim(0, 30)
  395. def test_pandas_indexing(pd):
  396. # Should not fail break when faced with a
  397. # non-zero indexed series
  398. index = [11, 12, 13]
  399. ec = fc = pd.Series(['red', 'blue', 'green'], index=index)
  400. lw = pd.Series([1, 2, 3], index=index)
  401. ls = pd.Series(['solid', 'dashed', 'dashdot'], index=index)
  402. aa = pd.Series([True, False, True], index=index)
  403. Collection(edgecolors=ec)
  404. Collection(facecolors=fc)
  405. Collection(linewidths=lw)
  406. Collection(linestyles=ls)
  407. Collection(antialiaseds=aa)
  408. @pytest.mark.style('default')
  409. def test_lslw_bcast():
  410. col = mcollections.PathCollection([])
  411. col.set_linestyles(['-', '-'])
  412. col.set_linewidths([1, 2, 3])
  413. assert col.get_linestyles() == [(0, None)] * 6
  414. assert col.get_linewidths() == [1, 2, 3] * 2
  415. col.set_linestyles(['-', '-', '-'])
  416. assert col.get_linestyles() == [(0, None)] * 3
  417. assert (col.get_linewidths() == [1, 2, 3]).all()
  418. @pytest.mark.style('default')
  419. def test_capstyle():
  420. col = mcollections.PathCollection([], capstyle='round')
  421. assert col.get_capstyle() == 'round'
  422. col.set_capstyle('butt')
  423. assert col.get_capstyle() == 'butt'
  424. @pytest.mark.style('default')
  425. def test_joinstyle():
  426. col = mcollections.PathCollection([], joinstyle='round')
  427. assert col.get_joinstyle() == 'round'
  428. col.set_joinstyle('miter')
  429. assert col.get_joinstyle() == 'miter'
  430. @image_comparison(['cap_and_joinstyle.png'])
  431. def test_cap_and_joinstyle_image():
  432. fig = plt.figure()
  433. ax = fig.add_subplot(1, 1, 1)
  434. ax.set_xlim([-0.5, 1.5])
  435. ax.set_ylim([-0.5, 2.5])
  436. x = np.array([0.0, 1.0, 0.5])
  437. ys = np.array([[0.0], [0.5], [1.0]]) + np.array([[0.0, 0.0, 1.0]])
  438. segs = np.zeros((3, 3, 2))
  439. segs[:, :, 0] = x
  440. segs[:, :, 1] = ys
  441. line_segments = LineCollection(segs, linewidth=[10, 15, 20])
  442. line_segments.set_capstyle("round")
  443. line_segments.set_joinstyle("miter")
  444. ax.add_collection(line_segments)
  445. ax.set_title('Line collection with customized caps and joinstyle')
  446. @image_comparison(['scatter_post_alpha.png'],
  447. remove_text=True, style='default')
  448. def test_scatter_post_alpha():
  449. fig, ax = plt.subplots()
  450. sc = ax.scatter(range(5), range(5), c=range(5))
  451. # this needs to be here to update internal state
  452. fig.canvas.draw()
  453. sc.set_alpha(.1)
  454. def test_pathcollection_legend_elements():
  455. np.random.seed(19680801)
  456. x, y = np.random.rand(2, 10)
  457. y = np.random.rand(10)
  458. c = np.random.randint(0, 5, size=10)
  459. s = np.random.randint(10, 300, size=10)
  460. fig, ax = plt.subplots()
  461. sc = ax.scatter(x, y, c=c, s=s, cmap="jet", marker="o", linewidths=0)
  462. h, l = sc.legend_elements(fmt="{x:g}")
  463. assert len(h) == 5
  464. assert_array_equal(np.array(l).astype(float), np.arange(5))
  465. colors = np.array([line.get_color() for line in h])
  466. colors2 = sc.cmap(np.arange(5)/4)
  467. assert_array_equal(colors, colors2)
  468. l1 = ax.legend(h, l, loc=1)
  469. h2, lab2 = sc.legend_elements(num=9)
  470. assert len(h2) == 9
  471. l2 = ax.legend(h2, lab2, loc=2)
  472. h, l = sc.legend_elements(prop="sizes", alpha=0.5, color="red")
  473. alpha = np.array([line.get_alpha() for line in h])
  474. assert_array_equal(alpha, 0.5)
  475. color = np.array([line.get_markerfacecolor() for line in h])
  476. assert_array_equal(color, "red")
  477. l3 = ax.legend(h, l, loc=4)
  478. h, l = sc.legend_elements(prop="sizes", num=4, fmt="{x:.2f}",
  479. func=lambda x: 2*x)
  480. actsizes = [line.get_markersize() for line in h]
  481. labeledsizes = np.sqrt(np.array(l).astype(float)/2)
  482. assert_array_almost_equal(actsizes, labeledsizes)
  483. l4 = ax.legend(h, l, loc=3)
  484. loc = mpl.ticker.MaxNLocator(nbins=9, min_n_ticks=9-1,
  485. steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
  486. h5, lab5 = sc.legend_elements(num=loc)
  487. assert len(h2) == len(h5)
  488. levels = [-1, 0, 55.4, 260]
  489. h6, lab6 = sc.legend_elements(num=levels, prop="sizes", fmt="{x:g}")
  490. assert_array_equal(np.array(lab6).astype(float), levels[2:])
  491. for l in [l1, l2, l3, l4]:
  492. ax.add_artist(l)
  493. fig.canvas.draw()
  494. def test_EventCollection_nosort():
  495. # Check that EventCollection doesn't modify input in place
  496. arr = np.array([3, 2, 1, 10])
  497. coll = EventCollection(arr)
  498. np.testing.assert_array_equal(arr, np.array([3, 2, 1, 10]))
  499. def test_collection_set_verts_array():
  500. verts = np.arange(80, dtype=np.double).reshape(10, 4, 2)
  501. col_arr = PolyCollection(verts)
  502. col_list = PolyCollection(list(verts))
  503. assert len(col_arr._paths) == len(col_list._paths)
  504. for ap, lp in zip(col_arr._paths, col_list._paths):
  505. assert np.array_equal(ap._vertices, lp._vertices)
  506. assert np.array_equal(ap._codes, lp._codes)
  507. verts_tuple = np.empty(10, dtype=object)
  508. verts_tuple[:] = [tuple(tuple(y) for y in x) for x in verts]
  509. col_arr_tuple = PolyCollection(verts_tuple)
  510. assert len(col_arr._paths) == len(col_arr_tuple._paths)
  511. for ap, atp in zip(col_arr._paths, col_arr_tuple._paths):
  512. assert np.array_equal(ap._vertices, atp._vertices)
  513. assert np.array_equal(ap._codes, atp._codes)
  514. def test_blended_collection_autolim():
  515. a = [1, 2, 4]
  516. height = .2
  517. xy_pairs = np.column_stack([np.repeat(a, 2), np.tile([0, height], len(a))])
  518. line_segs = xy_pairs.reshape([len(a), 2, 2])
  519. f, ax = plt.subplots()
  520. trans = mtransforms.blended_transform_factory(ax.transData, ax.transAxes)
  521. ax.add_collection(LineCollection(line_segs, transform=trans))
  522. ax.autoscale_view(scalex=True, scaley=False)
  523. np.testing.assert_allclose(ax.get_xlim(), [1., 4.])
  524. def test_singleton_autolim():
  525. fig, ax = plt.subplots()
  526. ax.scatter(0, 0)
  527. np.testing.assert_allclose(ax.get_ylim(), [-0.06, 0.06])
  528. np.testing.assert_allclose(ax.get_xlim(), [-0.06, 0.06])
  529. def test_quadmesh_set_array():
  530. x = np.arange(4)
  531. y = np.arange(4)
  532. z = np.arange(9).reshape((3, 3))
  533. fig, ax = plt.subplots()
  534. coll = ax.pcolormesh(x, y, np.ones(z.shape))
  535. # Test that the collection is able to update with a 2d array
  536. coll.set_array(z)
  537. fig.canvas.draw()
  538. assert np.array_equal(coll.get_array(), z)
  539. # Check that pre-flattened arrays work too
  540. coll.set_array(np.ones(9))
  541. fig.canvas.draw()
  542. assert np.array_equal(coll.get_array(), np.ones(9))