# axis3d.py, original mplot3d version by John Porter
# Created: 23 Sep 2005
# Parts rewritten by Reinier Heeres <reinier@heeres.eu>
  4. import numpy as np
  5. import matplotlib.transforms as mtransforms
  6. from matplotlib import (
  7. artist, lines as mlines, axis as maxis, patches as mpatches, rcParams)
  8. from . import art3d, proj3d
  9. def move_from_center(coord, centers, deltas, axmask=(True, True, True)):
  10. """
  11. For each coordinate where *axmask* is True, move *coord* away from
  12. *centers* by *deltas*.
  13. """
  14. coord = np.asarray(coord)
  15. return coord + axmask * np.copysign(1, coord - centers) * deltas
  16. def tick_update_position(tick, tickxs, tickys, labelpos):
  17. """Update tick line and label position and style."""
  18. tick.label1.set_position(labelpos)
  19. tick.label2.set_position(labelpos)
  20. tick.tick1line.set_visible(True)
  21. tick.tick2line.set_visible(False)
  22. tick.tick1line.set_linestyle('-')
  23. tick.tick1line.set_marker('')
  24. tick.tick1line.set_data(tickxs, tickys)
  25. tick.gridline.set_data(0, 0)
  26. class Axis(maxis.XAxis):
  27. """An Axis class for the 3D plots."""
  28. # These points from the unit cube make up the x, y and z-planes
  29. _PLANES = (
  30. (0, 3, 7, 4), (1, 2, 6, 5), # yz planes
  31. (0, 1, 5, 4), (3, 2, 6, 7), # xz planes
  32. (0, 1, 2, 3), (4, 5, 6, 7), # xy planes
  33. )
  34. # Some properties for the axes
  35. _AXINFO = {
  36. 'x': {'i': 0, 'tickdir': 1, 'juggled': (1, 0, 2),
  37. 'color': (0.95, 0.95, 0.95, 0.5)},
  38. 'y': {'i': 1, 'tickdir': 0, 'juggled': (0, 1, 2),
  39. 'color': (0.90, 0.90, 0.90, 0.5)},
  40. 'z': {'i': 2, 'tickdir': 0, 'juggled': (0, 2, 1),
  41. 'color': (0.925, 0.925, 0.925, 0.5)},
  42. }
  43. def __init__(self, adir, v_intervalx, d_intervalx, axes, *args,
  44. rotate_label=None, **kwargs):
  45. # adir identifies which axes this is
  46. self.adir = adir
  47. # This is a temporary member variable.
  48. # Do not depend on this existing in future releases!
  49. self._axinfo = self._AXINFO[adir].copy()
  50. if rcParams['_internal.classic_mode']:
  51. self._axinfo.update({
  52. 'label': {'va': 'center', 'ha': 'center'},
  53. 'tick': {
  54. 'inward_factor': 0.2,
  55. 'outward_factor': 0.1,
  56. 'linewidth': {
  57. True: rcParams['lines.linewidth'], # major
  58. False: rcParams['lines.linewidth'], # minor
  59. }
  60. },
  61. 'axisline': {'linewidth': 0.75, 'color': (0, 0, 0, 1)},
  62. 'grid': {
  63. 'color': (0.9, 0.9, 0.9, 1),
  64. 'linewidth': 1.0,
  65. 'linestyle': '-',
  66. },
  67. })
  68. else:
  69. self._axinfo.update({
  70. 'label': {'va': 'center', 'ha': 'center'},
  71. 'tick': {
  72. 'inward_factor': 0.2,
  73. 'outward_factor': 0.1,
  74. 'linewidth': {
  75. True: ( # major
  76. rcParams['xtick.major.width'] if adir in 'xz' else
  77. rcParams['ytick.major.width']),
  78. False: ( # minor
  79. rcParams['xtick.minor.width'] if adir in 'xz' else
  80. rcParams['ytick.minor.width']),
  81. }
  82. },
  83. 'axisline': {
  84. 'linewidth': rcParams['axes.linewidth'],
  85. 'color': rcParams['axes.edgecolor'],
  86. },
  87. 'grid': {
  88. 'color': rcParams['grid.color'],
  89. 'linewidth': rcParams['grid.linewidth'],
  90. 'linestyle': rcParams['grid.linestyle'],
  91. },
  92. })
  93. maxis.XAxis.__init__(self, axes, *args, **kwargs)
  94. # data and viewing intervals for this direction
  95. self.d_interval = d_intervalx
  96. self.v_interval = v_intervalx
  97. self.set_rotate_label(rotate_label)
  98. def init3d(self):
  99. self.line = mlines.Line2D(
  100. xdata=(0, 0), ydata=(0, 0),
  101. linewidth=self._axinfo['axisline']['linewidth'],
  102. color=self._axinfo['axisline']['color'],
  103. antialiased=True)
  104. # Store dummy data in Polygon object
  105. self.pane = mpatches.Polygon(
  106. np.array([[0, 0], [0, 1], [1, 0], [0, 0]]),
  107. closed=False, alpha=0.8, facecolor='k', edgecolor='k')
  108. self.set_pane_color(self._axinfo['color'])
  109. self.axes._set_artist_props(self.line)
  110. self.axes._set_artist_props(self.pane)
  111. self.gridlines = art3d.Line3DCollection([])
  112. self.axes._set_artist_props(self.gridlines)
  113. self.axes._set_artist_props(self.label)
  114. self.axes._set_artist_props(self.offsetText)
  115. # Need to be able to place the label at the correct location
  116. self.label._transform = self.axes.transData
  117. self.offsetText._transform = self.axes.transData
  118. def get_major_ticks(self, numticks=None):
  119. ticks = maxis.XAxis.get_major_ticks(self, numticks)
  120. for t in ticks:
  121. for obj in [
  122. t.tick1line, t.tick2line, t.gridline, t.label1, t.label2]:
  123. obj.set_transform(self.axes.transData)
  124. return ticks
  125. def get_minor_ticks(self, numticks=None):
  126. ticks = maxis.XAxis.get_minor_ticks(self, numticks)
  127. for t in ticks:
  128. for obj in [
  129. t.tick1line, t.tick2line, t.gridline, t.label1, t.label2]:
  130. obj.set_transform(self.axes.transData)
  131. return ticks
  132. def set_pane_pos(self, xys):
  133. xys = np.asarray(xys)
  134. xys = xys[:, :2]
  135. self.pane.xy = xys
  136. self.stale = True
  137. def set_pane_color(self, color):
  138. """Set pane color to a RGBA tuple."""
  139. self._axinfo['color'] = color
  140. self.pane.set_edgecolor(color)
  141. self.pane.set_facecolor(color)
  142. self.pane.set_alpha(color[-1])
  143. self.stale = True
  144. def set_rotate_label(self, val):
  145. """
  146. Whether to rotate the axis label: True, False or None.
  147. If set to None the label will be rotated if longer than 4 chars.
  148. """
  149. self._rotate_label = val
  150. self.stale = True
  151. def get_rotate_label(self, text):
  152. if self._rotate_label is not None:
  153. return self._rotate_label
  154. else:
  155. return len(text) > 4
  156. def _get_coord_info(self, renderer):
  157. mins, maxs = np.array([
  158. self.axes.get_xbound(),
  159. self.axes.get_ybound(),
  160. self.axes.get_zbound(),
  161. ]).T
  162. centers = (maxs + mins) / 2.
  163. deltas = (maxs - mins) / 12.
  164. mins = mins - deltas / 4.
  165. maxs = maxs + deltas / 4.
  166. vals = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2]
  167. tc = self.axes.tunit_cube(vals, renderer.M)
  168. avgz = [tc[p1][2] + tc[p2][2] + tc[p3][2] + tc[p4][2]
  169. for p1, p2, p3, p4 in self._PLANES]
  170. highs = np.array([avgz[2*i] < avgz[2*i+1] for i in range(3)])
  171. return mins, maxs, centers, deltas, tc, highs
  172. def draw_pane(self, renderer):
  173. renderer.open_group('pane3d', gid=self.get_gid())
  174. mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)
  175. info = self._axinfo
  176. index = info['i']
  177. if not highs[index]:
  178. plane = self._PLANES[2 * index]
  179. else:
  180. plane = self._PLANES[2 * index + 1]
  181. xys = [tc[p] for p in plane]
  182. self.set_pane_pos(xys)
  183. self.pane.draw(renderer)
  184. renderer.close_group('pane3d')
  185. @artist.allow_rasterization
  186. def draw(self, renderer):
  187. self.label._transform = self.axes.transData
  188. renderer.open_group('axis3d', gid=self.get_gid())
  189. ticks = self._update_ticks()
  190. info = self._axinfo
  191. index = info['i']
  192. mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)
  193. # Determine grid lines
  194. minmax = np.where(highs, maxs, mins)
  195. maxmin = np.where(highs, mins, maxs)
  196. # Draw main axis line
  197. juggled = info['juggled']
  198. edgep1 = minmax.copy()
  199. edgep1[juggled[0]] = maxmin[juggled[0]]
  200. edgep2 = edgep1.copy()
  201. edgep2[juggled[1]] = maxmin[juggled[1]]
  202. pep = np.asarray(
  203. proj3d.proj_trans_points([edgep1, edgep2], renderer.M))
  204. centpt = proj3d.proj_transform(*centers, renderer.M)
  205. self.line.set_data(pep[0], pep[1])
  206. self.line.draw(renderer)
  207. # Grid points where the planes meet
  208. xyz0 = np.tile(minmax, (len(ticks), 1))
  209. xyz0[:, index] = [tick.get_loc() for tick in ticks]
  210. # Draw labels
  211. # The transAxes transform is used because the Text object
  212. # rotates the text relative to the display coordinate system.
  213. # Therefore, if we want the labels to remain parallel to the
  214. # axis regardless of the aspect ratio, we need to convert the
  215. # edge points of the plane to display coordinates and calculate
  216. # an angle from that.
  217. # TODO: Maybe Text objects should handle this themselves?
  218. dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) -
  219. self.axes.transAxes.transform([pep[0:2, 0]]))[0]
  220. lxyz = 0.5 * (edgep1 + edgep2)
  221. # A rough estimate; points are ambiguous since 3D plots rotate
  222. ax_scale = self.axes.bbox.size / self.figure.bbox.size
  223. ax_inches = np.multiply(ax_scale, self.figure.get_size_inches())
  224. ax_points_estimate = sum(72. * ax_inches)
  225. deltas_per_point = 48 / ax_points_estimate
  226. default_offset = 21.
  227. labeldeltas = (
  228. (self.labelpad + default_offset) * deltas_per_point * deltas)
  229. axmask = [True, True, True]
  230. axmask[index] = False
  231. lxyz = move_from_center(lxyz, centers, labeldeltas, axmask)
  232. tlx, tly, tlz = proj3d.proj_transform(*lxyz, renderer.M)
  233. self.label.set_position((tlx, tly))
  234. if self.get_rotate_label(self.label.get_text()):
  235. angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx)))
  236. self.label.set_rotation(angle)
  237. self.label.set_va(info['label']['va'])
  238. self.label.set_ha(info['label']['ha'])
  239. self.label.draw(renderer)
  240. # Draw Offset text
  241. # Which of the two edge points do we want to
  242. # use for locating the offset text?
  243. if juggled[2] == 2:
  244. outeredgep = edgep1
  245. outerindex = 0
  246. else:
  247. outeredgep = edgep2
  248. outerindex = 1
  249. pos = move_from_center(outeredgep, centers, labeldeltas, axmask)
  250. olx, oly, olz = proj3d.proj_transform(*pos, renderer.M)
  251. self.offsetText.set_text(self.major.formatter.get_offset())
  252. self.offsetText.set_position((olx, oly))
  253. angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx)))
  254. self.offsetText.set_rotation(angle)
  255. # Must set rotation mode to "anchor" so that
  256. # the alignment point is used as the "fulcrum" for rotation.
  257. self.offsetText.set_rotation_mode('anchor')
  258. #----------------------------------------------------------------------
  259. # Note: the following statement for determining the proper alignment of
  260. # the offset text. This was determined entirely by trial-and-error
  261. # and should not be in any way considered as "the way". There are
  262. # still some edge cases where alignment is not quite right, but this
  263. # seems to be more of a geometry issue (in other words, I might be
  264. # using the wrong reference points).
  265. #
  266. # (TT, FF, TF, FT) are the shorthand for the tuple of
  267. # (centpt[info['tickdir']] <= pep[info['tickdir'], outerindex],
  268. # centpt[index] <= pep[index, outerindex])
  269. #
  270. # Three-letters (e.g., TFT, FTT) are short-hand for the array of bools
  271. # from the variable 'highs'.
  272. # ---------------------------------------------------------------------
  273. if centpt[info['tickdir']] > pep[info['tickdir'], outerindex]:
  274. # if FT and if highs has an even number of Trues
  275. if (centpt[index] <= pep[index, outerindex]
  276. and np.count_nonzero(highs) % 2 == 0):
  277. # Usually, this means align right, except for the FTT case,
  278. # in which offset for axis 1 and 2 are aligned left.
  279. if highs.tolist() == [False, True, True] and index in (1, 2):
  280. align = 'left'
  281. else:
  282. align = 'right'
  283. else:
  284. # The FF case
  285. align = 'left'
  286. else:
  287. # if TF and if highs has an even number of Trues
  288. if (centpt[index] > pep[index, outerindex]
  289. and np.count_nonzero(highs) % 2 == 0):
  290. # Usually mean align left, except if it is axis 2
  291. if index == 2:
  292. align = 'right'
  293. else:
  294. align = 'left'
  295. else:
  296. # The TT case
  297. align = 'right'
  298. self.offsetText.set_va('center')
  299. self.offsetText.set_ha(align)
  300. self.offsetText.draw(renderer)
  301. if self.axes._draw_grid and len(ticks):
  302. # Grid lines go from the end of one plane through the plane
  303. # intersection (at xyz0) to the end of the other plane. The first
  304. # point (0) differs along dimension index-2 and the last (2) along
  305. # dimension index-1.
  306. lines = np.stack([xyz0, xyz0, xyz0], axis=1)
  307. lines[:, 0, index - 2] = maxmin[index - 2]
  308. lines[:, 2, index - 1] = maxmin[index - 1]
  309. self.gridlines.set_segments(lines)
  310. self.gridlines.set_color(info['grid']['color'])
  311. self.gridlines.set_linewidth(info['grid']['linewidth'])
  312. self.gridlines.set_linestyle(info['grid']['linestyle'])
  313. self.gridlines.draw(renderer, project=True)
  314. # Draw ticks
  315. tickdir = info['tickdir']
  316. tickdelta = deltas[tickdir]
  317. if highs[tickdir]:
  318. ticksign = 1
  319. else:
  320. ticksign = -1
  321. for tick in ticks:
  322. # Get tick line positions
  323. pos = edgep1.copy()
  324. pos[index] = tick.get_loc()
  325. pos[tickdir] = (
  326. edgep1[tickdir]
  327. + info['tick']['outward_factor'] * ticksign * tickdelta)
  328. x1, y1, z1 = proj3d.proj_transform(*pos, renderer.M)
  329. pos[tickdir] = (
  330. edgep1[tickdir]
  331. - info['tick']['inward_factor'] * ticksign * tickdelta)
  332. x2, y2, z2 = proj3d.proj_transform(*pos, renderer.M)
  333. # Get position of label
  334. default_offset = 8. # A rough estimate
  335. labeldeltas = (
  336. (tick.get_pad() + default_offset) * deltas_per_point * deltas)
  337. axmask = [True, True, True]
  338. axmask[index] = False
  339. pos[tickdir] = edgep1[tickdir]
  340. pos = move_from_center(pos, centers, labeldeltas, axmask)
  341. lx, ly, lz = proj3d.proj_transform(*pos, renderer.M)
  342. tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly))
  343. tick.tick1line.set_linewidth(
  344. info['tick']['linewidth'][tick._major])
  345. tick.draw(renderer)
  346. renderer.close_group('axis3d')
  347. self.stale = False
  348. # TODO: Get this to work (more) properly when mplot3d supports the
  349. # transforms framework.
  350. def get_tightbbox(self, renderer, *, for_layout_only=False):
  351. # inherited docstring
  352. if not self.get_visible():
  353. return
  354. # We have to directly access the internal data structures
  355. # (and hope they are up to date) because at draw time we
  356. # shift the ticks and their labels around in (x, y) space
  357. # based on the projection, the current view port, and their
  358. # position in 3D space. If we extend the transforms framework
  359. # into 3D we would not need to do this different book keeping
  360. # than we do in the normal axis
  361. major_locs = self.get_majorticklocs()
  362. minor_locs = self.get_minorticklocs()
  363. ticks = [*self.get_minor_ticks(len(minor_locs)),
  364. *self.get_major_ticks(len(major_locs))]
  365. view_low, view_high = self.get_view_interval()
  366. if view_low > view_high:
  367. view_low, view_high = view_high, view_low
  368. interval_t = self.get_transform().transform([view_low, view_high])
  369. ticks_to_draw = []
  370. for tick in ticks:
  371. try:
  372. loc_t = self.get_transform().transform(tick.get_loc())
  373. except AssertionError:
  374. # Transform.transform doesn't allow masked values but
  375. # some scales might make them, so we need this try/except.
  376. pass
  377. else:
  378. if mtransforms._interval_contains_close(interval_t, loc_t):
  379. ticks_to_draw.append(tick)
  380. ticks = ticks_to_draw
  381. bb_1, bb_2 = self._get_tick_bboxes(ticks, renderer)
  382. other = []
  383. if self.line.get_visible():
  384. other.append(self.line.get_window_extent(renderer))
  385. if (self.label.get_visible() and not for_layout_only and
  386. self.label.get_text()):
  387. other.append(self.label.get_window_extent(renderer))
  388. return mtransforms.Bbox.union([*bb_1, *bb_2, *other])
  389. @property
  390. def d_interval(self):
  391. return self.get_data_interval()
  392. @d_interval.setter
  393. def d_interval(self, minmax):
  394. self.set_data_interval(*minmax)
  395. @property
  396. def v_interval(self):
  397. return self.get_view_interval()
  398. @v_interval.setter
  399. def v_interval(self, minmax):
  400. self.set_view_interval(*minmax)
# Concrete 3D axes: each subclass reads its view/data limits from a
# different interval of the parent axes' limit boxes.
  402. class XAxis(Axis):
  403. get_view_interval, set_view_interval = maxis._make_getset_interval(
  404. "view", "xy_viewLim", "intervalx")
  405. get_data_interval, set_data_interval = maxis._make_getset_interval(
  406. "data", "xy_dataLim", "intervalx")
  407. class YAxis(Axis):
  408. get_view_interval, set_view_interval = maxis._make_getset_interval(
  409. "view", "xy_viewLim", "intervaly")
  410. get_data_interval, set_data_interval = maxis._make_getset_interval(
  411. "data", "xy_dataLim", "intervaly")
  412. class ZAxis(Axis):
  413. get_view_interval, set_view_interval = maxis._make_getset_interval(
  414. "view", "zz_viewLim", "intervalx")
  415. get_data_interval, set_data_interval = maxis._make_getset_interval(
  416. "data", "zz_dataLim", "intervalx")