viz2d.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. """
  2. 2D visualization primitives based on Matplotlib.
  3. 1) Plot images with `plot_images`.
  4. 2) Call `plot_keypoints` or `plot_matches` any number of times.
  5. 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
  6. """
  7. import matplotlib
  8. import matplotlib.patheffects as path_effects
  9. import matplotlib.pyplot as plt
  10. import numpy as np
  11. import torch
  12. def cm_RdGn(x):
  13. """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
  14. x = np.clip(x, 0, 1)[..., None] * 2
  15. c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
  16. return np.clip(c, 0, 1)
  17. def cm_BlRdGn(x_):
  18. """Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
  19. x = np.clip(x_, 0, 1)[..., None] * 2
  20. c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
  21. xn = -np.clip(x_, -1, 0)[..., None] * 2
  22. cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
  23. out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
  24. return out
  25. def cm_prune(x_):
  26. """Custom colormap to visualize pruning"""
  27. if isinstance(x_, torch.Tensor):
  28. x_ = x_.cpu().numpy()
  29. max_i = max(x_)
  30. norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
  31. return cm_BlRdGn(norm_x)
  32. def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
  33. """Plot a set of images horizontally.
  34. Args:
  35. imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W).
  36. titles: a list of strings, as titles for each image.
  37. cmaps: colormaps for monochrome images.
  38. adaptive: whether the figure size should fit the image aspect ratios.
  39. """
  40. # conversion to (H, W, 3) for torch.Tensor
  41. imgs = [
  42. img.permute(1, 2, 0).cpu().numpy()
  43. if (isinstance(img, torch.Tensor) and img.dim() == 3)
  44. else img
  45. for img in imgs
  46. ]
  47. n = len(imgs)
  48. if not isinstance(cmaps, (list, tuple)):
  49. cmaps = [cmaps] * n
  50. if adaptive:
  51. ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H
  52. else:
  53. ratios = [4 / 3] * n
  54. figsize = [sum(ratios) * 4.5, 4.5]
  55. fig, ax = plt.subplots(
  56. 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
  57. )
  58. if n == 1:
  59. ax = [ax]
  60. for i in range(n):
  61. ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
  62. ax[i].get_yaxis().set_ticks([])
  63. ax[i].get_xaxis().set_ticks([])
  64. ax[i].set_axis_off()
  65. for spine in ax[i].spines.values(): # remove frame
  66. spine.set_visible(False)
  67. if titles:
  68. ax[i].set_title(titles[i])
  69. fig.tight_layout(pad=pad)
  70. def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
  71. """Plot keypoints for existing images.
  72. Args:
  73. kpts: list of ndarrays of size (N, 2).
  74. colors: string, or list of list of tuples (one for each keypoints).
  75. ps: size of the keypoints as float.
  76. """
  77. if not isinstance(colors, list):
  78. colors = [colors] * len(kpts)
  79. if not isinstance(a, list):
  80. a = [a] * len(kpts)
  81. if axes is None:
  82. axes = plt.gcf().axes
  83. for ax, k, c, alpha in zip(axes, kpts, colors, a):
  84. if isinstance(k, torch.Tensor):
  85. k = k.cpu().numpy()
  86. ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)
  87. def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
  88. """Plot matches for a pair of existing images.
  89. Args:
  90. kpts0, kpts1: corresponding keypoints of size (N, 2).
  91. color: color of each match, string or RGB tuple. Random if not given.
  92. lw: width of the lines.
  93. ps: size of the end points (no endpoint if ps=0)
  94. indices: indices of the images to draw the matches on.
  95. a: alpha opacity of the match lines.
  96. """
  97. fig = plt.gcf()
  98. if axes is None:
  99. ax = fig.axes
  100. ax0, ax1 = ax[0], ax[1]
  101. else:
  102. ax0, ax1 = axes
  103. if isinstance(kpts0, torch.Tensor):
  104. kpts0 = kpts0.cpu().numpy()
  105. if isinstance(kpts1, torch.Tensor):
  106. kpts1 = kpts1.cpu().numpy()
  107. assert len(kpts0) == len(kpts1)
  108. if color is None:
  109. color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
  110. elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
  111. color = [color] * len(kpts0)
  112. if lw > 0:
  113. for i in range(len(kpts0)):
  114. line = matplotlib.patches.ConnectionPatch(
  115. xyA=(kpts0[i, 0], kpts0[i, 1]),
  116. xyB=(kpts1[i, 0], kpts1[i, 1]),
  117. coordsA=ax0.transData,
  118. coordsB=ax1.transData,
  119. axesA=ax0,
  120. axesB=ax1,
  121. zorder=1,
  122. color=color[i],
  123. linewidth=lw,
  124. clip_on=True,
  125. alpha=a,
  126. label=None if labels is None else labels[i],
  127. picker=5.0,
  128. )
  129. line.set_annotation_clip(True)
  130. fig.add_artist(line)
  131. # freeze the axes to prevent the transform to change
  132. ax0.autoscale(enable=False)
  133. ax1.autoscale(enable=False)
  134. if ps > 0:
  135. ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
  136. ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
  137. def add_text(
  138. idx,
  139. text,
  140. pos=(0.01, 0.99),
  141. fs=15,
  142. color="w",
  143. lcolor="k",
  144. lwidth=2,
  145. ha="left",
  146. va="top",
  147. ):
  148. ax = plt.gcf().axes[idx]
  149. t = ax.text(
  150. *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
  151. )
  152. if lcolor is not None:
  153. t.set_path_effects(
  154. [
  155. path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
  156. path_effects.Normal(),
  157. ]
  158. )
  159. def save_plot(path, **kw):
  160. """Save the current figure without any white margin."""
  161. plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)