utils.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import logging
  2. import os
  3. import torch
  4. from torchvision import transforms
  5. import numpy as np
  6. import random
  7. import cv2
  8. from PIL import Image
  9. def path_to_image(path, size=(1024, 1024), color_type=['rgb', 'gray'][0]):
  10. if color_type.lower() == 'rgb':
  11. image = cv2.imread(path)
  12. elif color_type.lower() == 'gray':
  13. image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
  14. else:
  15. print('Select the color_type to return, either to RGB or gray image.')
  16. return
  17. if size:
  18. image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
  19. if color_type.lower() == 'rgb':
  20. image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert('RGB')
  21. else:
  22. image = Image.fromarray(image).convert('L')
  23. return image
  24. def check_state_dict(state_dict, unwanted_prefixes=['module.', '_orig_mod.']):
  25. for k, v in list(state_dict.items()):
  26. prefix_length = 0
  27. for unwanted_prefix in unwanted_prefixes:
  28. if k[prefix_length:].startswith(unwanted_prefix):
  29. prefix_length += len(unwanted_prefix)
  30. state_dict[k[prefix_length:]] = state_dict.pop(k)
  31. return state_dict
  32. def generate_smoothed_gt(gts):
  33. epsilon = 0.001
  34. new_gts = (1-epsilon)*gts+epsilon/2
  35. return new_gts
  36. class Logger():
  37. def __init__(self, path="log.txt"):
  38. self.logger = logging.getLogger('BiRefNet')
  39. self.file_handler = logging.FileHandler(path, "w")
  40. self.stdout_handler = logging.StreamHandler()
  41. self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
  42. self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
  43. self.logger.addHandler(self.file_handler)
  44. self.logger.addHandler(self.stdout_handler)
  45. self.logger.setLevel(logging.INFO)
  46. self.logger.propagate = False
  47. def info(self, txt):
  48. self.logger.info(txt)
  49. def close(self):
  50. self.file_handler.close()
  51. self.stdout_handler.close()
  52. class AverageMeter(object):
  53. """Computes and stores the average and current value"""
  54. def __init__(self):
  55. self.reset()
  56. def reset(self):
  57. self.val = 0.0
  58. self.avg = 0.0
  59. self.sum = 0.0
  60. self.count = 0.0
  61. def update(self, val, n=1):
  62. self.val = val
  63. self.sum += val * n
  64. self.count += n
  65. self.avg = self.sum / self.count
  66. def save_checkpoint(state, path, filename="latest.pth"):
  67. torch.save(state, os.path.join(path, filename))
  68. def save_tensor_img(tenor_im, path):
  69. im = tenor_im.cpu().clone()
  70. im = im.squeeze(0)
  71. tensor2pil = transforms.ToPILImage()
  72. im = tensor2pil(im)
  73. im.save(path)
  74. def set_seed(seed):
  75. torch.manual_seed(seed)
  76. torch.cuda.manual_seed_all(seed)
  77. np.random.seed(seed)
  78. random.seed(seed)
  79. torch.backends.cudnn.deterministic = True