grid_finder.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import numpy as np
  2. from matplotlib import cbook, ticker as mticker
  3. from matplotlib.transforms import Bbox, Transform
  4. from .clip_path import clip_line_to_rect
  5. def _deprecate_factor_none(factor):
  6. # After the deprecation period, calls to _deprecate_factor_none can just be
  7. # removed.
  8. if factor is None:
  9. cbook.warn_deprecated(
  10. "3.2", message="factor=None is deprecated since %(since)s and "
  11. "support will be removed %(removal)s; use/return factor=1 instead")
  12. factor = 1
  13. return factor
  14. class ExtremeFinderSimple:
  15. """
  16. A helper class to figure out the range of grid lines that need to be drawn.
  17. """
  18. def __init__(self, nx, ny):
  19. """
  20. Parameters
  21. ----------
  22. nx, ny : int
  23. The number of samples in each direction.
  24. """
  25. self.nx = nx
  26. self.ny = ny
  27. def __call__(self, transform_xy, x1, y1, x2, y2):
  28. """
  29. Compute an approximation of the bounding box obtained by applying
  30. *transform_xy* to the box delimited by ``(x1, y1, x2, y2)``.
  31. The intended use is to have ``(x1, y1, x2, y2)`` in axes coordinates,
  32. and have *transform_xy* be the transform from axes coordinates to data
  33. coordinates; this method then returns the range of data coordinates
  34. that span the actual axes.
  35. The computation is done by sampling ``nx * ny`` equispaced points in
  36. the ``(x1, y1, x2, y2)`` box and finding the resulting points with
  37. extremal coordinates; then adding some padding to take into account the
  38. finite sampling.
  39. As each sampling step covers a relative range of *1/nx* or *1/ny*,
  40. the padding is computed by expanding the span covered by the extremal
  41. coordinates by these fractions.
  42. """
  43. x, y = np.meshgrid(
  44. np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
  45. xt, yt = transform_xy(np.ravel(x), np.ravel(y))
  46. return self._add_pad(xt.min(), xt.max(), yt.min(), yt.max())
  47. def _add_pad(self, x_min, x_max, y_min, y_max):
  48. """Perform the padding mentioned in `__call__`."""
  49. dx = (x_max - x_min) / self.nx
  50. dy = (y_max - y_min) / self.ny
  51. return x_min - dx, x_max + dx, y_min - dy, y_max + dy
  52. class GridFinder:
  53. def __init__(self,
  54. transform,
  55. extreme_finder=None,
  56. grid_locator1=None,
  57. grid_locator2=None,
  58. tick_formatter1=None,
  59. tick_formatter2=None):
  60. """
  61. transform : transform from the image coordinate (which will be
  62. the transData of the axes to the world coordinate.
  63. or transform = (transform_xy, inv_transform_xy)
  64. locator1, locator2 : grid locator for 1st and 2nd axis.
  65. """
  66. if extreme_finder is None:
  67. extreme_finder = ExtremeFinderSimple(20, 20)
  68. if grid_locator1 is None:
  69. grid_locator1 = MaxNLocator()
  70. if grid_locator2 is None:
  71. grid_locator2 = MaxNLocator()
  72. if tick_formatter1 is None:
  73. tick_formatter1 = FormatterPrettyPrint()
  74. if tick_formatter2 is None:
  75. tick_formatter2 = FormatterPrettyPrint()
  76. self.extreme_finder = extreme_finder
  77. self.grid_locator1 = grid_locator1
  78. self.grid_locator2 = grid_locator2
  79. self.tick_formatter1 = tick_formatter1
  80. self.tick_formatter2 = tick_formatter2
  81. self.update_transform(transform)
  82. def get_grid_info(self, x1, y1, x2, y2):
  83. """
  84. lon_values, lat_values : list of grid values. if integer is given,
  85. rough number of grids in each direction.
  86. """
  87. extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)
  88. # min & max rage of lat (or lon) for each grid line will be drawn.
  89. # i.e., gridline of lon=0 will be drawn from lat_min to lat_max.
  90. lon_min, lon_max, lat_min, lat_max = extremes
  91. lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
  92. lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)
  93. lon_values = lon_levs[:lon_n] / _deprecate_factor_none(lon_factor)
  94. lat_values = lat_levs[:lat_n] / _deprecate_factor_none(lat_factor)
  95. lon_lines, lat_lines = self._get_raw_grid_lines(lon_values,
  96. lat_values,
  97. lon_min, lon_max,
  98. lat_min, lat_max)
  99. ddx = (x2-x1)*1.e-10
  100. ddy = (y2-y1)*1.e-10
  101. bb = Bbox.from_extents(x1-ddx, y1-ddy, x2+ddx, y2+ddy)
  102. grid_info = {
  103. "extremes": extremes,
  104. "lon_lines": lon_lines,
  105. "lat_lines": lat_lines,
  106. "lon": self._clip_grid_lines_and_find_ticks(
  107. lon_lines, lon_values, lon_levs, bb),
  108. "lat": self._clip_grid_lines_and_find_ticks(
  109. lat_lines, lat_values, lat_levs, bb),
  110. }
  111. tck_labels = grid_info["lon"]["tick_labels"] = {}
  112. for direction in ["left", "bottom", "right", "top"]:
  113. levs = grid_info["lon"]["tick_levels"][direction]
  114. tck_labels[direction] = self.tick_formatter1(
  115. direction, lon_factor, levs)
  116. tck_labels = grid_info["lat"]["tick_labels"] = {}
  117. for direction in ["left", "bottom", "right", "top"]:
  118. levs = grid_info["lat"]["tick_levels"][direction]
  119. tck_labels[direction] = self.tick_formatter2(
  120. direction, lat_factor, levs)
  121. return grid_info
  122. def _get_raw_grid_lines(self,
  123. lon_values, lat_values,
  124. lon_min, lon_max, lat_min, lat_max):
  125. lons_i = np.linspace(lon_min, lon_max, 100) # for interpolation
  126. lats_i = np.linspace(lat_min, lat_max, 100)
  127. lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i)
  128. for lon in lon_values]
  129. lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat))
  130. for lat in lat_values]
  131. return lon_lines, lat_lines
  132. def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
  133. gi = {
  134. "values": [],
  135. "levels": [],
  136. "tick_levels": dict(left=[], bottom=[], right=[], top=[]),
  137. "tick_locs": dict(left=[], bottom=[], right=[], top=[]),
  138. "lines": [],
  139. }
  140. tck_levels = gi["tick_levels"]
  141. tck_locs = gi["tick_locs"]
  142. for (lx, ly), v, lev in zip(lines, values, levs):
  143. xy, tcks = clip_line_to_rect(lx, ly, bb)
  144. if not xy:
  145. continue
  146. gi["levels"].append(v)
  147. gi["lines"].append(xy)
  148. for tck, direction in zip(tcks,
  149. ["left", "bottom", "right", "top"]):
  150. for t in tck:
  151. tck_levels[direction].append(lev)
  152. tck_locs[direction].append(t)
  153. return gi
  154. def update_transform(self, aux_trans):
  155. if isinstance(aux_trans, Transform):
  156. def transform_xy(x, y):
  157. ll1 = np.column_stack([x, y])
  158. ll2 = aux_trans.transform(ll1)
  159. lon, lat = ll2[:, 0], ll2[:, 1]
  160. return lon, lat
  161. def inv_transform_xy(x, y):
  162. ll1 = np.column_stack([x, y])
  163. ll2 = aux_trans.inverted().transform(ll1)
  164. lon, lat = ll2[:, 0], ll2[:, 1]
  165. return lon, lat
  166. else:
  167. transform_xy, inv_transform_xy = aux_trans
  168. self.transform_xy = transform_xy
  169. self.inv_transform_xy = inv_transform_xy
  170. def update(self, **kw):
  171. for k in kw:
  172. if k in ["extreme_finder",
  173. "grid_locator1",
  174. "grid_locator2",
  175. "tick_formatter1",
  176. "tick_formatter2"]:
  177. setattr(self, k, kw[k])
  178. else:
  179. raise ValueError("Unknown update property '%s'" % k)
  180. @cbook.deprecated("3.2")
  181. class GridFinderBase(GridFinder):
  182. def __init__(self,
  183. extreme_finder,
  184. grid_locator1=None,
  185. grid_locator2=None,
  186. tick_formatter1=None,
  187. tick_formatter2=None):
  188. super().__init__((None, None), extreme_finder,
  189. grid_locator1, grid_locator2,
  190. tick_formatter1, tick_formatter2)
  191. class MaxNLocator(mticker.MaxNLocator):
  192. def __init__(self, nbins=10, steps=None,
  193. trim=True,
  194. integer=False,
  195. symmetric=False,
  196. prune=None):
  197. # trim argument has no effect. It has been left for API compatibility
  198. mticker.MaxNLocator.__init__(self, nbins, steps=steps,
  199. integer=integer,
  200. symmetric=symmetric, prune=prune)
  201. self.create_dummy_axis()
  202. self._factor = 1
  203. def __call__(self, v1, v2):
  204. self.set_bounds(v1 * self._factor, v2 * self._factor)
  205. locs = mticker.MaxNLocator.__call__(self)
  206. return np.array(locs), len(locs), self._factor
  207. @cbook.deprecated("3.3")
  208. def set_factor(self, f):
  209. self._factor = _deprecate_factor_none(f)
  210. class FixedLocator:
  211. def __init__(self, locs):
  212. self._locs = locs
  213. self._factor = 1
  214. def __call__(self, v1, v2):
  215. v1, v2 = sorted([v1 * self._factor, v2 * self._factor])
  216. locs = np.array([l for l in self._locs if v1 <= l <= v2])
  217. return locs, len(locs), self._factor
  218. @cbook.deprecated("3.3")
  219. def set_factor(self, f):
  220. self._factor = _deprecate_factor_none(f)
  221. # Tick Formatter
  222. class FormatterPrettyPrint:
  223. def __init__(self, useMathText=True):
  224. self._fmt = mticker.ScalarFormatter(
  225. useMathText=useMathText, useOffset=False)
  226. self._fmt.create_dummy_axis()
  227. def __call__(self, direction, factor, values):
  228. return self._fmt.format_ticks(values)
  229. class DictFormatter:
  230. def __init__(self, format_dict, formatter=None):
  231. """
  232. format_dict : dictionary for format strings to be used.
  233. formatter : fall-back formatter
  234. """
  235. super().__init__()
  236. self._format_dict = format_dict
  237. self._fallback_formatter = formatter
  238. def __call__(self, direction, factor, values):
  239. """
  240. factor is ignored if value is found in the dictionary
  241. """
  242. if self._fallback_formatter:
  243. fallback_strings = self._fallback_formatter(
  244. direction, factor, values)
  245. else:
  246. fallback_strings = [""] * len(values)
  247. return [self._format_dict.get(k, v)
  248. for k, v in zip(values, fallback_strings)]