spines.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. import numpy as np
  2. import matplotlib
  3. from matplotlib import cbook, docstring, rcParams
  4. from matplotlib.artist import allow_rasterization
  5. import matplotlib.transforms as mtransforms
  6. import matplotlib.patches as mpatches
  7. import matplotlib.path as mpath
  8. class Spine(mpatches.Patch):
  9. """
  10. An axis spine -- the line noting the data area boundaries.
  11. Spines are the lines connecting the axis tick marks and noting the
  12. boundaries of the data area. They can be placed at arbitrary
  13. positions. See `~.Spine.set_position` for more information.
  14. The default position is ``('outward', 0)``.
  15. Spines are subclasses of `.Patch`, and inherit much of their behavior.
  16. Spines draw a line, a circle, or an arc depending if
  17. `~.Spine.set_patch_line`, `~.Spine.set_patch_circle`, or
  18. `~.Spine.set_patch_arc` has been called. Line-like is the default.
  19. """
  20. def __str__(self):
  21. return "Spine"
  22. @docstring.dedent_interpd
  23. def __init__(self, axes, spine_type, path, **kwargs):
  24. """
  25. Parameters
  26. ----------
  27. axes : `~matplotlib.axes.Axes`
  28. The `~.axes.Axes` instance containing the spine.
  29. spine_type : str
  30. The spine type.
  31. path : `~matplotlib.path.Path`
  32. The `.Path` instance used to draw the spine.
  33. Other Parameters
  34. ----------------
  35. **kwargs
  36. Valid keyword arguments are:
  37. %(Patch)s
  38. """
  39. super().__init__(**kwargs)
  40. self.axes = axes
  41. self.set_figure(self.axes.figure)
  42. self.spine_type = spine_type
  43. self.set_facecolor('none')
  44. self.set_edgecolor(rcParams['axes.edgecolor'])
  45. self.set_linewidth(rcParams['axes.linewidth'])
  46. self.set_capstyle('projecting')
  47. self.axis = None
  48. self.set_zorder(2.5)
  49. self.set_transform(self.axes.transData) # default transform
  50. self._bounds = None # default bounds
  51. self._smart_bounds = False # deprecated in 3.2
  52. # Defer initial position determination. (Not much support for
  53. # non-rectangular axes is currently implemented, and this lets
  54. # them pass through the spines machinery without errors.)
  55. self._position = None
  56. cbook._check_isinstance(matplotlib.path.Path, path=path)
  57. self._path = path
  58. # To support drawing both linear and circular spines, this
  59. # class implements Patch behavior three ways. If
  60. # self._patch_type == 'line', behave like a mpatches.PathPatch
  61. # instance. If self._patch_type == 'circle', behave like a
  62. # mpatches.Ellipse instance. If self._patch_type == 'arc', behave like
  63. # a mpatches.Arc instance.
  64. self._patch_type = 'line'
  65. # Behavior copied from mpatches.Ellipse:
  66. # Note: This cannot be calculated until this is added to an Axes
  67. self._patch_transform = mtransforms.IdentityTransform()
  68. @cbook.deprecated("3.2")
  69. def set_smart_bounds(self, value):
  70. """Set the spine and associated axis to have smart bounds."""
  71. self._smart_bounds = value
  72. # also set the axis if possible
  73. if self.spine_type in ('left', 'right'):
  74. self.axes.yaxis.set_smart_bounds(value)
  75. elif self.spine_type in ('top', 'bottom'):
  76. self.axes.xaxis.set_smart_bounds(value)
  77. self.stale = True
  78. @cbook.deprecated("3.2")
  79. def get_smart_bounds(self):
  80. """Return whether the spine has smart bounds."""
  81. return self._smart_bounds
  82. def set_patch_arc(self, center, radius, theta1, theta2):
  83. """Set the spine to be arc-like."""
  84. self._patch_type = 'arc'
  85. self._center = center
  86. self._width = radius * 2
  87. self._height = radius * 2
  88. self._theta1 = theta1
  89. self._theta2 = theta2
  90. self._path = mpath.Path.arc(theta1, theta2)
  91. # arc drawn on axes transform
  92. self.set_transform(self.axes.transAxes)
  93. self.stale = True
  94. def set_patch_circle(self, center, radius):
  95. """Set the spine to be circular."""
  96. self._patch_type = 'circle'
  97. self._center = center
  98. self._width = radius * 2
  99. self._height = radius * 2
  100. # circle drawn on axes transform
  101. self.set_transform(self.axes.transAxes)
  102. self.stale = True
  103. def set_patch_line(self):
  104. """Set the spine to be linear."""
  105. self._patch_type = 'line'
  106. self.stale = True
  107. # Behavior copied from mpatches.Ellipse:
  108. def _recompute_transform(self):
  109. """
  110. Notes
  111. -----
  112. This cannot be called until after this has been added to an Axes,
  113. otherwise unit conversion will fail. This makes it very important to
  114. call the accessor method and not directly access the transformation
  115. member variable.
  116. """
  117. assert self._patch_type in ('arc', 'circle')
  118. center = (self.convert_xunits(self._center[0]),
  119. self.convert_yunits(self._center[1]))
  120. width = self.convert_xunits(self._width)
  121. height = self.convert_yunits(self._height)
  122. self._patch_transform = mtransforms.Affine2D() \
  123. .scale(width * 0.5, height * 0.5) \
  124. .translate(*center)
  125. def get_patch_transform(self):
  126. if self._patch_type in ('arc', 'circle'):
  127. self._recompute_transform()
  128. return self._patch_transform
  129. else:
  130. return super().get_patch_transform()
  131. def get_window_extent(self, renderer=None):
  132. """
  133. Return the window extent of the spines in display space, including
  134. padding for ticks (but not their labels)
  135. See Also
  136. --------
  137. matplotlib.axes.Axes.get_tightbbox
  138. matplotlib.axes.Axes.get_window_extent
  139. """
  140. # make sure the location is updated so that transforms etc are correct:
  141. self._adjust_location()
  142. bb = super().get_window_extent(renderer=renderer)
  143. if self.axis is None:
  144. return bb
  145. bboxes = [bb]
  146. tickstocheck = [self.axis.majorTicks[0]]
  147. if len(self.axis.minorTicks) > 1:
  148. # only pad for minor ticks if there are more than one
  149. # of them. There is always one...
  150. tickstocheck.append(self.axis.minorTicks[1])
  151. for tick in tickstocheck:
  152. bb0 = bb.frozen()
  153. tickl = tick._size
  154. tickdir = tick._tickdir
  155. if tickdir == 'out':
  156. padout = 1
  157. padin = 0
  158. elif tickdir == 'in':
  159. padout = 0
  160. padin = 1
  161. else:
  162. padout = 0.5
  163. padin = 0.5
  164. padout = padout * tickl / 72 * self.figure.dpi
  165. padin = padin * tickl / 72 * self.figure.dpi
  166. if tick.tick1line.get_visible():
  167. if self.spine_type == 'left':
  168. bb0.x0 = bb0.x0 - padout
  169. bb0.x1 = bb0.x1 + padin
  170. elif self.spine_type == 'bottom':
  171. bb0.y0 = bb0.y0 - padout
  172. bb0.y1 = bb0.y1 + padin
  173. if tick.tick2line.get_visible():
  174. if self.spine_type == 'right':
  175. bb0.x1 = bb0.x1 + padout
  176. bb0.x0 = bb0.x0 - padin
  177. elif self.spine_type == 'top':
  178. bb0.y1 = bb0.y1 + padout
  179. bb0.y0 = bb0.y0 - padout
  180. bboxes.append(bb0)
  181. return mtransforms.Bbox.union(bboxes)
  182. def get_path(self):
  183. return self._path
  184. def _ensure_position_is_set(self):
  185. if self._position is None:
  186. # default position
  187. self._position = ('outward', 0.0) # in points
  188. self.set_position(self._position)
  189. def register_axis(self, axis):
  190. """
  191. Register an axis.
  192. An axis should be registered with its corresponding spine from
  193. the Axes instance. This allows the spine to clear any axis
  194. properties when needed.
  195. """
  196. self.axis = axis
  197. if self.axis is not None:
  198. self.axis.cla()
  199. self.stale = True
  200. def cla(self):
  201. """Clear the current spine."""
  202. self._position = None # clear position
  203. if self.axis is not None:
  204. self.axis.cla()
  205. def _adjust_location(self):
  206. """Automatically set spine bounds to the view interval."""
  207. if self.spine_type == 'circle':
  208. return
  209. if self._bounds is None:
  210. if self.spine_type in ('left', 'right'):
  211. low, high = self.axes.viewLim.intervaly
  212. elif self.spine_type in ('top', 'bottom'):
  213. low, high = self.axes.viewLim.intervalx
  214. else:
  215. raise ValueError('unknown spine spine_type: %s' %
  216. self.spine_type)
  217. if self._smart_bounds: # deprecated in 3.2
  218. # attempt to set bounds in sophisticated way
  219. # handle inverted limits
  220. viewlim_low, viewlim_high = sorted([low, high])
  221. if self.spine_type in ('left', 'right'):
  222. datalim_low, datalim_high = self.axes.dataLim.intervaly
  223. ticks = self.axes.get_yticks()
  224. elif self.spine_type in ('top', 'bottom'):
  225. datalim_low, datalim_high = self.axes.dataLim.intervalx
  226. ticks = self.axes.get_xticks()
  227. # handle inverted limits
  228. ticks = np.sort(ticks)
  229. datalim_low, datalim_high = sorted([datalim_low, datalim_high])
  230. if datalim_low < viewlim_low:
  231. # Data extends past view. Clip line to view.
  232. low = viewlim_low
  233. else:
  234. # Data ends before view ends.
  235. cond = (ticks <= datalim_low) & (ticks >= viewlim_low)
  236. tickvals = ticks[cond]
  237. if len(tickvals):
  238. # A tick is less than or equal to lowest data point.
  239. low = tickvals[-1]
  240. else:
  241. # No tick is available
  242. low = datalim_low
  243. low = max(low, viewlim_low)
  244. if datalim_high > viewlim_high:
  245. # Data extends past view. Clip line to view.
  246. high = viewlim_high
  247. else:
  248. # Data ends before view ends.
  249. cond = (ticks >= datalim_high) & (ticks <= viewlim_high)
  250. tickvals = ticks[cond]
  251. if len(tickvals):
  252. # A tick is greater than or equal to highest data
  253. # point.
  254. high = tickvals[0]
  255. else:
  256. # No tick is available
  257. high = datalim_high
  258. high = min(high, viewlim_high)
  259. else:
  260. low, high = self._bounds
  261. if self._patch_type == 'arc':
  262. if self.spine_type in ('bottom', 'top'):
  263. try:
  264. direction = self.axes.get_theta_direction()
  265. except AttributeError:
  266. direction = 1
  267. try:
  268. offset = self.axes.get_theta_offset()
  269. except AttributeError:
  270. offset = 0
  271. low = low * direction + offset
  272. high = high * direction + offset
  273. if low > high:
  274. low, high = high, low
  275. self._path = mpath.Path.arc(np.rad2deg(low), np.rad2deg(high))
  276. if self.spine_type == 'bottom':
  277. rmin, rmax = self.axes.viewLim.intervaly
  278. try:
  279. rorigin = self.axes.get_rorigin()
  280. except AttributeError:
  281. rorigin = rmin
  282. scaled_diameter = (rmin - rorigin) / (rmax - rorigin)
  283. self._height = scaled_diameter
  284. self._width = scaled_diameter
  285. else:
  286. raise ValueError('unable to set bounds for spine "%s"' %
  287. self.spine_type)
  288. else:
  289. v1 = self._path.vertices
  290. assert v1.shape == (2, 2), 'unexpected vertices shape'
  291. if self.spine_type in ['left', 'right']:
  292. v1[0, 1] = low
  293. v1[1, 1] = high
  294. elif self.spine_type in ['bottom', 'top']:
  295. v1[0, 0] = low
  296. v1[1, 0] = high
  297. else:
  298. raise ValueError('unable to set bounds for spine "%s"' %
  299. self.spine_type)
  300. @allow_rasterization
  301. def draw(self, renderer):
  302. self._adjust_location()
  303. ret = super().draw(renderer)
  304. self.stale = False
  305. return ret
  306. def set_position(self, position):
  307. """
  308. Set the position of the spine.
  309. Spine position is specified by a 2 tuple of (position type,
  310. amount). The position types are:
  311. * 'outward': place the spine out from the data area by the specified
  312. number of points. (Negative values place the spine inwards.)
  313. * 'axes': place the spine at the specified Axes coordinate (0 to 1).
  314. * 'data': place the spine at the specified data coordinate.
  315. Additionally, shorthand notations define a special positions:
  316. * 'center' -> ('axes', 0.5)
  317. * 'zero' -> ('data', 0.0)
  318. """
  319. if position in ('center', 'zero'): # special positions
  320. pass
  321. else:
  322. if len(position) != 2:
  323. raise ValueError("position should be 'center' or 2-tuple")
  324. if position[0] not in ['outward', 'axes', 'data']:
  325. raise ValueError("position[0] should be one of 'outward', "
  326. "'axes', or 'data' ")
  327. self._position = position
  328. self.set_transform(self.get_spine_transform())
  329. if self.axis is not None:
  330. self.axis.reset_ticks()
  331. self.stale = True
  332. def get_position(self):
  333. """Return the spine position."""
  334. self._ensure_position_is_set()
  335. return self._position
  336. def get_spine_transform(self):
  337. """Return the spine transform."""
  338. self._ensure_position_is_set()
  339. position = self._position
  340. if isinstance(position, str):
  341. if position == 'center':
  342. position = ('axes', 0.5)
  343. elif position == 'zero':
  344. position = ('data', 0)
  345. assert len(position) == 2, 'position should be 2-tuple'
  346. position_type, amount = position
  347. cbook._check_in_list(['axes', 'outward', 'data'],
  348. position_type=position_type)
  349. if self.spine_type in ['left', 'right']:
  350. base_transform = self.axes.get_yaxis_transform(which='grid')
  351. elif self.spine_type in ['top', 'bottom']:
  352. base_transform = self.axes.get_xaxis_transform(which='grid')
  353. else:
  354. raise ValueError(f'unknown spine spine_type: {self.spine_type!r}')
  355. if position_type == 'outward':
  356. if amount == 0: # short circuit commonest case
  357. return base_transform
  358. else:
  359. offset_vec = {'left': (-1, 0), 'right': (1, 0),
  360. 'bottom': (0, -1), 'top': (0, 1),
  361. }[self.spine_type]
  362. # calculate x and y offset in dots
  363. offset_dots = amount * np.array(offset_vec) / 72
  364. return (base_transform
  365. + mtransforms.ScaledTranslation(
  366. *offset_dots, self.figure.dpi_scale_trans))
  367. elif position_type == 'axes':
  368. if self.spine_type in ['left', 'right']:
  369. # keep y unchanged, fix x at amount
  370. return (mtransforms.Affine2D.from_values(0, 0, 0, 1, amount, 0)
  371. + base_transform)
  372. elif self.spine_type in ['bottom', 'top']:
  373. # keep x unchanged, fix y at amount
  374. return (mtransforms.Affine2D.from_values(1, 0, 0, 0, 0, amount)
  375. + base_transform)
  376. elif position_type == 'data':
  377. if self.spine_type in ('right', 'top'):
  378. # The right and top spines have a default position of 1 in
  379. # axes coordinates. When specifying the position in data
  380. # coordinates, we need to calculate the position relative to 0.
  381. amount -= 1
  382. if self.spine_type in ('left', 'right'):
  383. return mtransforms.blended_transform_factory(
  384. mtransforms.Affine2D().translate(amount, 0)
  385. + self.axes.transData,
  386. self.axes.transData)
  387. elif self.spine_type in ('bottom', 'top'):
  388. return mtransforms.blended_transform_factory(
  389. self.axes.transData,
  390. mtransforms.Affine2D().translate(0, amount)
  391. + self.axes.transData)
  392. def set_bounds(self, low=None, high=None):
  393. """
  394. Set the spine bounds.
  395. Parameters
  396. ----------
  397. low : float or None, optional
  398. The lower spine bound. Passing *None* leaves the limit unchanged.
  399. The bounds may also be passed as the tuple (*low*, *high*) as the
  400. first positional argument.
  401. .. ACCEPTS: (low: float, high: float)
  402. high : float or None, optional
  403. The higher spine bound. Passing *None* leaves the limit unchanged.
  404. """
  405. if self.spine_type == 'circle':
  406. raise ValueError(
  407. 'set_bounds() method incompatible with circular spines')
  408. if high is None and np.iterable(low):
  409. low, high = low
  410. old_low, old_high = self.get_bounds() or (None, None)
  411. if low is None:
  412. low = old_low
  413. if high is None:
  414. high = old_high
  415. self._bounds = (low, high)
  416. self.stale = True
  417. def get_bounds(self):
  418. """Get the bounds of the spine."""
  419. return self._bounds
  420. @classmethod
  421. def linear_spine(cls, axes, spine_type, **kwargs):
  422. """Create and return a linear `Spine`."""
  423. # all values of 0.999 get replaced upon call to set_bounds()
  424. if spine_type == 'left':
  425. path = mpath.Path([(0.0, 0.999), (0.0, 0.999)])
  426. elif spine_type == 'right':
  427. path = mpath.Path([(1.0, 0.999), (1.0, 0.999)])
  428. elif spine_type == 'bottom':
  429. path = mpath.Path([(0.999, 0.0), (0.999, 0.0)])
  430. elif spine_type == 'top':
  431. path = mpath.Path([(0.999, 1.0), (0.999, 1.0)])
  432. else:
  433. raise ValueError('unable to make path for spine "%s"' % spine_type)
  434. result = cls(axes, spine_type, path, **kwargs)
  435. result.set_visible(rcParams['axes.spines.{0}'.format(spine_type)])
  436. return result
  437. @classmethod
  438. def arc_spine(cls, axes, spine_type, center, radius, theta1, theta2,
  439. **kwargs):
  440. """Create and return an arc `Spine`."""
  441. path = mpath.Path.arc(theta1, theta2)
  442. result = cls(axes, spine_type, path, **kwargs)
  443. result.set_patch_arc(center, radius, theta1, theta2)
  444. return result
  445. @classmethod
  446. def circular_spine(cls, axes, center, radius, **kwargs):
  447. """Create and return a circular `Spine`."""
  448. path = mpath.Path.unit_circle()
  449. spine_type = 'circle'
  450. result = cls(axes, spine_type, path, **kwargs)
  451. result.set_patch_circle(center, radius)
  452. return result
  453. def set_color(self, c):
  454. """
  455. Set the edgecolor.
  456. Parameters
  457. ----------
  458. c : color
  459. Notes
  460. -----
  461. This method does not modify the facecolor (which defaults to "none"),
  462. unlike the `.Patch.set_color` method defined in the parent class. Use
  463. `.Patch.set_facecolor` to set the facecolor.
  464. """
  465. self.set_edgecolor(c)
  466. self.stale = True