base.py 1.2 KB

1234567891011121314151617181920212223242526272829303132
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Any
  15. from torchmetrics.metric import Metric
  16. class _ClassificationTaskWrapper(Metric):
  17. """Base class for wrapper metrics for classification tasks."""
  18. def update(self, *args: Any, **kwargs: Any) -> None:
  19. """Update metric state."""
  20. raise NotImplementedError(
  21. f"{self.__class__.__name__} metric does not have a global `update` method. Use the task specific metric."
  22. )
  23. def compute(self) -> None:
  24. """Compute metric."""
  25. raise NotImplementedError(
  26. f"{self.__class__.__name__} metric does not have a global `compute` method. Use the task specific metric."
  27. )