rrdbnet.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import torch
  18. from torch import nn
  19. from kornia.config import kornia_config
  20. from kornia.core.external import basicsr
  21. from kornia.models.utils import OutputRangePostProcessor
  22. from kornia.utils.download import CachedDownloader
  23. from .base import SuperResolution
# Public API of this module: only the builder is exported by ``from <module> import *``.
__all__ = ["RRDBNetBuilder"]
  25. URLs = {
  26. "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
  27. "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
  28. "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
  29. "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
  30. }
  31. class RRDBNetBuilder:
  32. @staticmethod
  33. def build(model_name: str = "RealESRNet_x4plus", pretrained: bool = True) -> SuperResolution:
  34. if model_name == "RealESRGAN_x4plus":
  35. model = basicsr.archs.rrdbnet_arch.RRDBNet( # type: ignore
  36. num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
  37. )
  38. elif model_name == "RealESRNet_x4plus":
  39. model = basicsr.archs.rrdbnet_arch.RRDBNet( # type: ignore
  40. num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
  41. )
  42. elif model_name == "RealESRGAN_x4plus_anime_6B":
  43. model = basicsr.archs.rrdbnet_arch.RRDBNet( # type: ignore
  44. num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4
  45. )
  46. elif model_name == "RealESRGAN_x2plus":
  47. model = basicsr.archs.rrdbnet_arch.RRDBNet( # type: ignore
  48. num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2
  49. )
  50. else:
  51. raise ValueError(
  52. f"Model {model_name} not found. Please choose from "
  53. "'RealESRGAN_x4plus', 'RealESRNet_x4plus', 'RealESRGAN_x4plus_anime_6B', 'RealESRGAN_x2plus'."
  54. )
  55. model_path = None
  56. if pretrained:
  57. url = URLs[model_name]
  58. model_path = CachedDownloader.download_to_cache(
  59. url, model_name, download=True, suffix=".pth", cache_dir=kornia_config.hub_onnx_dir
  60. )
  61. model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))["params_ema"], strict=True)
  62. model.eval()
  63. return SuperResolution(
  64. model,
  65. pre_processor=nn.Identity(),
  66. post_processor=OutputRangePostProcessor(min_val=0.0, max_val=1.0),
  67. name=model_name,
  68. )