interpolate.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
""" Interpolation helpers for timm layers

RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations
Copyright Shane Barratt, Apache 2.0 license
"""
  5. import torch
  6. from itertools import product
  7. class RegularGridInterpolator:
  8. """ Interpolate data defined on a rectilinear grid with even or uneven spacing.
  9. Produces similar results to scipy RegularGridInterpolator or interp2d
  10. in 'linear' mode.
  11. Taken from https://github.com/sbarratt/torch_interpolations
  12. """
  13. def __init__(self, points, values):
  14. self.points = points
  15. self.values = values
  16. assert isinstance(self.points, tuple) or isinstance(self.points, list)
  17. assert isinstance(self.values, torch.Tensor)
  18. self.ms = list(self.values.shape)
  19. self.n = len(self.points)
  20. assert len(self.ms) == self.n
  21. for i, p in enumerate(self.points):
  22. assert isinstance(p, torch.Tensor)
  23. assert p.shape[0] == self.values.shape[i]
  24. def __call__(self, points_to_interp):
  25. assert self.points is not None
  26. assert self.values is not None
  27. assert len(points_to_interp) == len(self.points)
  28. K = points_to_interp[0].shape[0]
  29. for x in points_to_interp:
  30. assert x.shape[0] == K
  31. idxs = []
  32. dists = []
  33. overalls = []
  34. for p, x in zip(self.points, points_to_interp):
  35. idx_right = torch.bucketize(x, p)
  36. idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1
  37. idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
  38. dist_left = x - p[idx_left]
  39. dist_right = p[idx_right] - x
  40. dist_left[dist_left < 0] = 0.
  41. dist_right[dist_right < 0] = 0.
  42. both_zero = (dist_left == 0) & (dist_right == 0)
  43. dist_left[both_zero] = dist_right[both_zero] = 1.
  44. idxs.append((idx_left, idx_right))
  45. dists.append((dist_left, dist_right))
  46. overalls.append(dist_left + dist_right)
  47. numerator = 0.
  48. for indexer in product([0, 1], repeat=self.n):
  49. as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)]
  50. bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
  51. numerator += self.values[as_s] * \
  52. torch.prod(torch.stack(bs_s), dim=0)
  53. denominator = torch.prod(torch.stack(overalls), dim=0)
  54. return numerator / denominator