kde.py 404 B

12345678910111213
  import torch


  def kde(x, std=0.1, half=True, down=None):
      """Estimate per-point density of ``x`` with a Gaussian kernel.

      For every point ``x_i`` this returns
      ``sum_j exp(-||x_i - c_j||**2 / (2 * std**2))``, where the kernel
      centres ``c`` are all points of ``x``, or every ``down``-th point when
      ``down`` is given. No normalization constant is applied, so the result
      is a relative density score, not a probability density.

      Args:
          x: Point tensor in a layout ``torch.cdist`` accepts, i.e. ``(N, D)``
              or batched ``(..., N, D)``.
          std: Kernel bandwidth (standard deviation of the Gaussian).
          half: If True, cast ``x`` to float16 before computing distances.
              NOTE(review): float16 saturates at 65504, so the summed score
              becomes ``inf`` with more than ~65k kernel centres; float16
              ``cdist`` on CPU may also be unsupported depending on the torch
              version -- confirm for CPU callers.
          down: Optional stride. If given, only every ``down``-th point along
              the point axis is used as a kernel centre, which shrinks the
              pairwise-distance matrix.

      Returns:
          Tensor of shape ``(N,)`` (or ``(..., N)`` for batched input) holding
          each point's density score, in the dtype of the (possibly downcast)
          input.
      """
      if half:
          x = x.half()  # TODO: float16 is the hard-coded default; reconsider
      # Stride the point axis (second-to-last), not the leading axis, so that
      # batched (..., N, D) inputs subsample points rather than batch entries.
      # For 2-D input this is exactly x[::down].
      centres = x if down is None else x[..., ::down, :]
      sq_dist = torch.cdist(x, centres).square()
      scores = torch.exp(-sq_dist / (2 * std**2))
      # Sum kernel contributions from all centres for each query point.
      return scores.sum(dim=-1)