# cls_postprocess.py

# import paddle
  2. class ClsPostProcess(object):
  3. """ Convert between text-label and text-index """
  4. def __init__(self, label_list=None, key=None, **kwargs):
  5. super(ClsPostProcess, self).__init__()
  6. self.label_list = label_list
  7. self.key = key
  8. def __call__(self, preds, label=None, *args, **kwargs):
  9. if self.key is not None:
  10. preds = preds[self.key]
  11. label_list = self.label_list
  12. if label_list is None:
  13. label_list = {idx: idx for idx in range(preds.shape[-1])}
  14. # if isinstance(preds, paddle.Tensor):
  15. # preds = preds.numpy()
  16. pred_idxs = preds.argmax(axis=1)
  17. decode_out = [(label_list[idx], preds[i, idx])
  18. for i, idx in enumerate(pred_idxs)]
  19. if label is None:
  20. return decode_out
  21. label = [(label_list[idx], 1.0) for idx in label]
  22. return decode_out, label