| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- """
- 2D visualization primitives based on Matplotlib.
- 1) Plot images with `plot_images`.
- 2) Call `plot_keypoints` or `plot_matches` any number of times.
- 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
- """
- import matplotlib
- import matplotlib.patheffects as path_effects
- import matplotlib.pyplot as plt
- import numpy as np
- import torch
def cm_RdGn(x):
    """Map values in [0, 1] onto a red -> yellow -> green RGB ramp.

    Inputs are clipped to [0, 1] first: 0 gives red, 0.5 yellow and 1 green.

    Args:
        x: scalar or array-like of values.

    Returns:
        Array of RGB values in [0, 1] with one extra trailing axis of size 3.
    """
    t = 2 * np.clip(x, 0, 1)[..., None]
    green = np.array([[0, 1.0, 0]])
    red = np.array([[1.0, 0, 0]])
    # Both weights reach 2 at the ends of the ramp, so saturate back into [0, 1].
    return np.clip(t * green + (2 - t) * red, 0, 1)
def cm_BlRdGn(x_):
    """Custom colormap: blue (-1) -> red (0.0) -> green (1).

    Args:
        x_: scalar, list, or array of values; values outside [-1, 1] are
            clipped.

    Returns:
        RGBA array in [0, 1] with one extra trailing axis of size 4 (alpha is
        always 1). Non-negative inputs blend red (0) -> yellow (0.5) ->
        green (1); negative inputs blend from red (0) towards blue (-1).
    """
    # Convert up front so the `[..., None]` indexing below also works for
    # Python lists and plain numbers, not only NumPy arrays.
    x_ = np.asarray(x_)
    x_col = x_[..., None]
    # Non-negative branch, scaled to [0, 2].
    pos = 2 * np.clip(x_col, 0, 1)
    c = pos * np.array([[0, 1.0, 0, 1.0]]) + (2 - pos) * np.array([[1.0, 0, 0, 1.0]])
    # Negative branch: |x| scaled to [0, 2].
    neg = -2 * np.clip(x_col, -1, 0)
    cn = neg * np.array([[0, 0.1, 1, 1.0]]) + (2 - neg) * np.array([[1.0, 0, 0, 1.0]])
    return np.clip(np.where(x_col < 0, cn, c), 0, 1)
def cm_prune(x_):
    """Custom colormap to visualize pruning.

    Entries equal to the maximum of ``x_`` are mapped to -1 (blue); every
    other entry ``v`` is mapped to ``(v - 1) / 9``, so values 1..10 span the
    red -> green half of ``cm_BlRdGn``.
    NOTE(review): ``x_`` presumably holds per-keypoint layer indices at which
    points were pruned -- confirm with callers.

    Args:
        x_: NumPy array, array-like, or PyTorch tensor of numbers (any shape).

    Returns:
        RGBA colors from ``cm_BlRdGn`` with a trailing axis of size 4.

    Raises:
        ValueError: if ``x_`` is empty (no maximum exists).
    """
    if isinstance(x_, torch.Tensor):
        # detach() so tensors that require grad can be converted as well.
        x_ = x_.detach().cpu().numpy()
    # asarray makes `==` and `-` elementwise for lists too; `.max()` (unlike
    # the builtin `max`) also works for arrays with more than one dimension.
    x_ = np.asarray(x_)
    max_i = x_.max()
    norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
    return cm_BlRdGn(norm_x)
def cm_grad2d(xy):
    """Bilinear 2D colormap over the unit square.

    Corner colors: (0, 0) yellow, (1, 0) green, (0, 1) red, (1, 1) blue.
    The first coordinate is read from ``xy[..., 0]`` and the second from
    ``xy[..., -1]``; both are clipped to [0, 1] before blending.

    Args:
        xy: array-like whose last axis holds the 2D coordinates.

    Returns:
        RGB array in [0, 1] whose last axis has size 3.
    """
    yellow = np.array([1.0, 1.0, 0])  # (0, 0)
    green = np.array([0, 1.0, 0])  # (1, 0)
    blue = np.array([0, 0.0, 1])  # (1, 1)
    red = np.array([1.0, 0, 0])  # (0, 1)
    xy = np.clip(xy, 0, 1)
    u, v = xy[..., :1], xy[..., -1:]
    rgb = (
        (1 - u) * (1 - v) * yellow
        + u * (1 - v) * green
        + u * v * blue
        + (1 - u) * v * red
    )
    return rgb.clip(0, 1)
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
    """Plot a set of images horizontally, in a single row of axes.

    Args:
        imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) images, or
            mono (H, W) images as NumPy arrays or PyTorch tensors.
        titles: optional list of strings, one title per image.
        cmaps: colormap name used for monochrome images, or a list/tuple with
            one colormap per image.
        dpi: resolution of the created figure.
        pad: padding passed to ``fig.tight_layout``.
        adaptive: whether the figure size should fit the image aspect ratios;
            otherwise every panel gets a 4:3 aspect ratio.
    """

    def _to_numpy(img):
        # Convert any torch.Tensor (not only 3D ones) so that imshow never
        # receives a GPU tensor or a tensor that requires grad.
        if not isinstance(img, torch.Tensor):
            return img
        img = img.detach().cpu()
        if img.dim() == 3:
            img = img.permute(1, 2, 0)  # (3, H, W) -> (H, W, 3)
        return img.numpy()

    imgs = [_to_numpy(img) for img in imgs]
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n
    if adaptive:
        ratios = [img.shape[1] / img.shape[0] for img in imgs]  # W / H
    else:
        ratios = [4 / 3] * n
    figsize = [sum(ratios) * 4.5, 4.5]
    # squeeze=False always returns a 2D grid of axes, even when n == 1.
    fig, axes = plt.subplots(
        1,
        n,
        figsize=figsize,
        dpi=dpi,
        squeeze=False,
        gridspec_kw={"width_ratios": ratios},
    )
    for i, ax in enumerate(axes[0]):
        ax.imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        ax.set_axis_off()
        for spine in ax.spines.values():  # remove frame
            spine.set_visible(False)
        if titles:
            ax.set_title(titles[i])
    fig.tight_layout(pad=pad)
def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
    """Plot keypoints on existing images.

    Args:
        kpts: list of arrays (NumPy, array-like, or PyTorch) of size (N, 2),
            one per axis; column 0 is plotted as x and column 1 as y.
        colors: a single color (string or tuple) used for every image, or a
            list with one entry per image (each entry may itself be a single
            color or one color per keypoint).
        ps: size of the keypoints as float (``s`` of ``Axes.scatter``).
        axes: axes to draw on; defaults to all axes of the current figure.
        a: alpha opacity, a float for all images or a list with one per image.
    """
    if not isinstance(colors, list):
        colors = [colors] * len(kpts)
    if not isinstance(a, list):
        a = [a] * len(kpts)
    if axes is None:
        axes = plt.gcf().axes
    # zip stops at the shortest input: unmatched axes/keypoint sets are skipped.
    for ax, k, c, alpha in zip(axes, kpts, colors, a):
        if isinstance(k, torch.Tensor):
            # detach() so tensors that require grad can be converted too.
            k = k.detach().cpu().numpy()
        k = np.asarray(k)
        ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
    """Plot matches for a pair of existing images.

    Args:
        kpts0, kpts1: corresponding keypoints of size (N, 2), as NumPy arrays,
            array-likes, or PyTorch tensors.
        color: color of the matches: a single color (string or RGB tuple), or
            one color per match (list of tuples, or an (N, 3)/(N, 4) array).
            If not given, colors follow a 2D gradient over ``kpts0``.
        lw: width of the lines (no lines if lw=0).
        ps: size of the end points (no endpoint if ps=0).
        a: alpha opacity of the match lines.
        labels: optional sequence of per-match labels set on the line artists.
        axes: pair of axes (ax0, ax1) to draw on; defaults to the first two
            axes of the current figure.
    """
    fig = plt.gcf()
    if axes is None:
        ax0, ax1 = fig.axes[0], fig.axes[1]
    else:
        ax0, ax1 = axes
    if isinstance(kpts0, torch.Tensor):
        # detach() so tensors that require grad can be converted too.
        kpts0 = kpts0.detach().cpu().numpy()
    if isinstance(kpts1, torch.Tensor):
        kpts1 = kpts1.detach().cpu().numpy()
    kpts0 = np.asarray(kpts0)
    kpts1 = np.asarray(kpts1)
    assert len(kpts0) == len(kpts1)
    if color is None:
        if len(kpts0) > 0:
            lo = kpts0.min(axis=0, keepdims=True)
            span = np.ptp(kpts0, axis=0, keepdims=True)
            # A zero extent (single match, or all points on one row/column)
            # would divide by zero and produce NaN colors.
            span = np.where(span > 0, span, 1)
            color = cm_grad2d((kpts0 - lo) / span)  # gradient color
        else:
            color = []  # min() has no identity on an empty array
    elif isinstance(color, np.ndarray) and color.ndim == 2:
        pass  # already one color per match; do not wrap it N times
    elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
        color = [color] * len(kpts0)
    if lw > 0:
        for i in range(len(kpts0)):
            line = matplotlib.patches.ConnectionPatch(
                xyA=(kpts0[i, 0], kpts0[i, 1]),
                xyB=(kpts1[i, 0], kpts1[i, 1]),
                coordsA=ax0.transData,
                coordsB=ax1.transData,
                axesA=ax0,
                axesB=ax1,
                zorder=1,
                color=color[i],
                linewidth=lw,
                clip_on=True,
                alpha=a,
                label=None if labels is None else labels[i],
                picker=5.0,
            )
            line.set_annotation_clip(True)
            fig.add_artist(line)
    # freeze the axes to prevent the transform to change
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)
    if ps > 0:
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
def add_text(
    idx,
    text,
    pos=(0.01, 0.99),
    fs=15,
    color="w",
    lcolor="k",
    lwidth=2,
    ha="left",
    va="top",
):
    """Draw ``text`` on the ``idx``-th axis of the current figure.

    Args:
        idx: index into ``plt.gcf().axes``.
        text: string to draw.
        pos: position unpacked as the leading positional arguments of
            ``Axes.text``, interpreted in axes coordinates (``ax.transAxes``).
        fs: font size.
        color: text color.
        lcolor: color of the outline stroke drawn around the text, or None to
            draw the text without an outline.
        lwidth: line width of that outline.
        ha, va: horizontal and vertical alignment.
    """
    target_ax = plt.gcf().axes[idx]
    text_kwargs = dict(
        fontsize=fs, ha=ha, va=va, color=color, transform=target_ax.transAxes
    )
    artist = target_ax.text(*pos, text, **text_kwargs)
    if lcolor is None:
        return
    outline = path_effects.Stroke(linewidth=lwidth, foreground=lcolor)
    artist.set_path_effects([outline, path_effects.Normal()])
def save_plot(path, **kw):
    """Write the current figure to ``path`` with no white margin around it.

    Extra keyword arguments are forwarded to ``plt.savefig``.
    """
    tight = {"bbox_inches": "tight", "pad_inches": 0}
    plt.savefig(path, **tight, **kw)
|