feature_share.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from functools import lru_cache
  16. from typing import Any, Optional, Union
  17. from torch.nn import Module
  18. from torchmetrics.collections import MetricCollection
  19. from torchmetrics.metric import Metric
  20. from torchmetrics.utilities import rank_zero_warn
  21. __doctest_requires__ = {("FeatureShare",): ["torch_fidelity"]}
  22. class NetworkCache(Module):
  23. """Create a cached version of a network to be shared between metrics.
  24. Because the different metrics may invoke the same network multiple times, we can save time by caching the input-
  25. output pairs of the network.
  26. """
  27. def __init__(self, network: Module, max_size: int = 100) -> None:
  28. super().__init__()
  29. self.max_size = max_size
  30. self.network = network
  31. self.network.forward = lru_cache(maxsize=self.max_size)(network.forward)
  32. def forward(self, *args: Any, **kwargs: Any) -> Any:
  33. """Call the network with the given arguments."""
  34. return self.network(*args, **kwargs)
  35. class FeatureShare(MetricCollection):
  36. """Specialized metric collection that facilitates sharing features between metrics.
  37. Certain metrics rely on an underlying expensive neural network for feature extraction when computing the metric.
  38. This wrapper allows to share the feature extraction between multiple metrics, which can save a lot of time and
  39. memory. This is achieved by making a shared instance of the network between the metrics and secondly by caching
  40. the input-output pairs of the network, such the subsequent calls to the network with the same input will be much
  41. faster.
  42. Args:
  43. metrics: One of the following:
  44. * list or tuple (sequence): if metrics are passed in as a list or tuple, will use the metrics class name
  45. as key for output dict. Therefore, two metrics of the same class cannot be chained this way.
  46. * dict: if metrics are passed in as a dict, will use each key in the dict as key for output dict.
  47. Use this format if you want to chain together multiple of the same metric with different parameters.
  48. Note that the keys in the output dict will be sorted alphabetically.
  49. max_cache_size: maximum number of input-output pairs to cache per metric. By default, this is none which means
  50. that the cache will be set to the number of metrics in the collection meaning that all features will be
  51. cached and shared across all metrics per batch.
  52. Example::
  53. >>> import torch
  54. >>> from torchmetrics.wrappers import FeatureShare
  55. >>> from torchmetrics.image import FrechetInceptionDistance, KernelInceptionDistance
  56. >>> # initialize the metrics
  57. >>> fs = FeatureShare([FrechetInceptionDistance(), KernelInceptionDistance(subset_size=10, subsets=2)])
  58. >>> # update metric
  59. >>> input_tensor = torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8, generator=torch.manual_seed(42))
  60. >>> fs.update(input_tensor, real=True)
  61. >>> input_tensor = torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8, generator=torch.manual_seed(43))
  62. >>> fs.update(input_tensor, real=False)
  63. >>> # compute metric
  64. >>> fs.compute()
  65. {'FrechetInceptionDistance': tensor(13.5367), 'KernelInceptionDistance': (tensor(0.0003), tensor(0.0003))}
  66. """
  67. def __init__(
  68. self,
  69. metrics: Union[Metric, Sequence[Metric], dict[str, Metric]],
  70. max_cache_size: Optional[int] = None,
  71. ) -> None:
  72. # disable compute groups because the feature sharing is more custom
  73. super().__init__(metrics=metrics, compute_groups=False) # type: ignore
  74. if max_cache_size is None:
  75. max_cache_size = len(self)
  76. if not isinstance(max_cache_size, int):
  77. raise TypeError(f"max_cache_size should be an integer, but got {max_cache_size}")
  78. try:
  79. first_net = next(iter(self.values()))
  80. if not isinstance(first_net.feature_network, str):
  81. raise TypeError("The `feature_network` attribute must be a string.")
  82. network_to_share = getattr(first_net, first_net.feature_network)
  83. except AttributeError as err:
  84. raise AttributeError(
  85. "Tried to extract the network to share from the first metric, but it did not have a `feature_network`"
  86. " attribute. Please make sure that the metric has an attribute with that name,"
  87. " else it cannot be shared."
  88. ) from err
  89. except TypeError as err:
  90. raise TypeError("The `feature_network` attribute must be a string representing the network name.") from err
  91. cached_net = NetworkCache(network_to_share, max_size=max_cache_size)
  92. # set the cached network to all metrics
  93. for metric_name, metric in self.items():
  94. if not hasattr(metric, "feature_network"):
  95. raise AttributeError(
  96. "Tried to set the cached network to all metrics, but one of the metrics did not have a"
  97. " `feature_network` attribute. Please make sure that all metrics have a attribute with that name,"
  98. f" else it cannot be shared. Failed on metric {metric_name}."
  99. )
  100. if not isinstance(metric.feature_network, str):
  101. raise TypeError(f"Metric {metric_name}'s `feature_network` attribute must be a string.")
  102. # check if its the same network as the first metric
  103. if str(getattr(metric, metric.feature_network)) != str(network_to_share):
  104. rank_zero_warn(
  105. f"The network to share between the metrics is not the same for all metrics."
  106. f" Metric {metric_name} has a different network than the first metric."
  107. " This may lead to unexpected behavior.",
  108. UserWarning,
  109. )
  110. setattr(metric, metric.feature_network, cached_net)