| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from torch import nn
- from kornia.core import tensor
- from kornia.enhance.normalize import Normalize
- from kornia.filters.dexined import DexiNed
- from kornia.models.edge_detection.base import EdgeDetector
- from kornia.models.utils import ResizePostProcessor, ResizePreProcessor
- __all__ = ["DexiNedBuilder"]
class DexiNedBuilder:
    """Factory for DexiNed-based :class:`EdgeDetector` instances.

    Example:

    .. code-block:: python

        images = kornia.utils.sample.get_sample_images()
        model = DexiNedBuilder.build()
        model.save(images)

    """

    @staticmethod
    def build(model_name: str = "dexined", pretrained: bool = True, image_size: int = 352) -> EdgeDetector:
        """Assemble an edge detector around a DexiNed network.

        Args:
            model_name: architecture name. It is compared case-insensitively
                against ``"dexined"``, the only supported value.
            pretrained: forwarded unchanged to :class:`DexiNed`.
            image_size: passed to :class:`ResizePreProcessor` as both the
                target height and width.

        Returns:
            An :class:`EdgeDetector` named ``"dexined"``. Its network is
            ``nn.Sequential(Normalize, DexiNed, nn.Sigmoid)``, with a
            ``ResizePreProcessor`` and a ``ResizePostProcessor`` around it.

        Raises:
            ValueError: if ``model_name`` does not equal ``"dexined"`` when
                lower-cased.
        """
        # Reject every name except the single supported architecture up front.
        if model_name.lower() != "dexined":
            raise ValueError(f"Model {model_name} not found. Please choose from 'DexiNed'.")

        # Normalize computes (x - mean) / std. With std = 1/255 that is
        # (x - mean) * 255: mean-centre each channel, then scale by 255.
        # NOTE(review): presumably inputs arrive in [0, 1]; confirm with callers.
        channel_mean = tensor([[0.485, 0.456, 0.406]])
        channel_std = tensor([[1.0 / 255.0] * 3])
        network = nn.Sequential(
            Normalize(mean=channel_mean, std=channel_std),
            DexiNed(pretrained=pretrained),
            nn.Sigmoid(),
        )

        pre_processor = ResizePreProcessor(image_size, image_size)
        post_processor = ResizePostProcessor()
        return EdgeDetector(network, pre_processor, post_processor, name="dexined")
|