| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import torch
- from torch import nn
- from kornia.config import kornia_config
- from kornia.core.external import basicsr
- from kornia.models.utils import OutputRangePostProcessor
- from kornia.utils.download import CachedDownloader
- from .base import SuperResolution
__all__ = ["RRDBNetBuilder"]

# Download locations of the official Real-ESRGAN release checkpoints, keyed by the
# model name accepted by ``RRDBNetBuilder.build``. Keyword order matches the
# insertion order of the entries (dict preserves keyword-argument order).
URLs = dict(
    RealESRGAN_x4plus="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    RealESRNet_x4plus="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    RealESRGAN_x4plus_anime_6B="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
    RealESRGAN_x2plus="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
)
class RRDBNetBuilder:
    """Builder for Real-ESRGAN-family super-resolution models based on ``basicsr``'s ``RRDBNet``.

    Supported model names are the keys of ``_ARCH_CONFIGS``; each one also needs a checkpoint
    URL in the module-level ``URLs`` mapping when ``pretrained=True``.
    """

    # Per-variant RRDBNet hyper-parameters. Every variant shares num_in_ch=3, num_out_ch=3,
    # num_feat=64 and num_grow_ch=32 (passed in ``build``); only the number of RRDB blocks and
    # the upscaling factor differ between them.
    _ARCH_CONFIGS = {
        "RealESRGAN_x4plus": {"num_block": 23, "scale": 4},
        "RealESRNet_x4plus": {"num_block": 23, "scale": 4},
        "RealESRGAN_x4plus_anime_6B": {"num_block": 6, "scale": 4},
        "RealESRGAN_x2plus": {"num_block": 23, "scale": 2},
    }

    @staticmethod
    def build(model_name: str = "RealESRNet_x4plus", pretrained: bool = True) -> SuperResolution:
        """Construct an RRDBNet super-resolution model wrapped in ``SuperResolution``.

        Args:
            model_name: one of ``"RealESRGAN_x4plus"``, ``"RealESRNet_x4plus"``,
                ``"RealESRGAN_x4plus_anime_6B"`` or ``"RealESRGAN_x2plus"``.
            pretrained: if ``True``, download the checkpoint listed in ``URLs`` for ``model_name``
                and load its weights (mapped to CPU, ``strict=True``). If ``False``, no checkpoint
                is loaded.

        Returns:
            A ``SuperResolution`` wrapping the network (switched to eval mode), with an
            ``nn.Identity`` pre-processor and an ``OutputRangePostProcessor(min_val=0.0, max_val=1.0)``
            post-processor.

        Raises:
            ValueError: if ``model_name`` is not one of the supported variants.
        """
        try:
            arch = RRDBNetBuilder._ARCH_CONFIGS[model_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs, which the original equality chain also
            # reported as "not found".
            raise ValueError(
                f"Model {model_name} not found. Please choose from "
                "'RealESRGAN_x4plus', 'RealESRNet_x4plus', 'RealESRGAN_x4plus_anime_6B', 'RealESRGAN_x2plus'."
            ) from None

        model = basicsr.archs.rrdbnet_arch.RRDBNet(  # type: ignore
            num_in_ch=3, num_out_ch=3, num_feat=64, num_grow_ch=32, **arch
        )

        if pretrained:
            # NOTE(review): these are .pth checkpoints but they are cached under ``hub_onnx_dir``;
            # confirm this is the intended cache directory.
            model_path = CachedDownloader.download_to_cache(
                URLs[model_name], model_name, download=True, suffix=".pth", cache_dir=kornia_config.hub_onnx_dir
            )
            # NOTE(review): torch.load unpickles the downloaded file. Consider ``weights_only=True``
            # once the minimum supported torch version accepts that argument.
            checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
            # Real-ESRGAN release checkpoints keep their EMA weights under "params_ema". Fall back to
            # "params" for checkpoints saved without EMA weights, as upstream Real-ESRGAN inference does.
            state_dict = checkpoint["params_ema"] if "params_ema" in checkpoint else checkpoint["params"]
            model.load_state_dict(state_dict, strict=True)

        model.eval()
        return SuperResolution(
            model,
            pre_processor=nn.Identity(),
            post_processor=OutputRangePostProcessor(min_val=0.0, max_val=1.0),
            name=model_name,
        )
|